"""Tests for recoil/pipeline/_lib/preflight.py — deterministic pre-flight validation."""

import pytest

from recoil.pipeline._lib.preflight import PreFlightChecker, PreFlightWarning, CostEstimate


# Model profiles injected into the `checker` fixture (no disk reads); also the pricing source for the cost-estimation tests
_TEST_PROFILES = {
    "gemini-3-pro-image-preview": {
        "modality": "image",
        "cost_per_image": 0.134,
    },
    "gemini-3.1-flash-image-preview": {
        "modality": "image",
        "cost_per_image": 0.039,
    },
    "seedream-v4.5": {
        "modality": "image",
        "cost_per_image": 0.04,
    },
    "kling-v3": {
        "modality": "video",
        "cost_per_second": 0.10,
        "cost_per_second_standard": 0.10,
    },
    "seeddance-2.0": {
        "modality": "video",
        "cost_per_second": 0.3034,
    },
    "veo-3.1": {
        "modality": "video",
        "cost_per_second": 0.05,
    },
}


def _make_plan(shots):
    """Wrap a list of shot dicts in a plan."""
    return {"shots": shots}


def _make_shot(shot_id="ep01_s01_001", pipeline="still", model="gemini-3-pro-image-preview", **overrides):
    """Helper to build a shot dict with sensible defaults."""
    shot = {
        "shot_id": shot_id,
        "pipeline": pipeline,
        "model": model,
        "prompt_data": {},
        "asset_data": {},
        "spatial_data": {},
        "routing_data": {},
    }
    shot.update(overrides)
    return shot


@pytest.fixture
def checker():
    """Provide a ``PreFlightChecker`` configured with ``_TEST_PROFILES`` (avoids reading profiles from disk)."""
    instance = PreFlightChecker(model_profiles=_TEST_PROFILES)
    return instance


# ── Camera Contradiction Detection ───────────────────────────────────


class TestCameraContradiction:
    """A shot's declared ``camera_movement`` must agree with its ``camera_line`` text."""

    def test_static_camera_with_pan_in_camera_line(self, checker):
        prompt = {
            "camera_movement": "static",
            "prompt_skeleton": {"camera_line": "Slow pan across the room"},
        }
        issues = checker.validate_batch(_make_plan([_make_shot(prompt_data=prompt)]))
        # Static camera + "pan" wording is expected to be reported as an error.
        flagged = [w for w in issues if w.check == "camera_contradiction"]
        assert any(w.severity == "error" for w in flagged)

    def test_static_camera_with_tilt_detected(self, checker):
        prompt = {
            "camera_movement": "static",
            "prompt_skeleton": {"camera_line": "Tilt up to reveal ceiling"},
        }
        issues = checker.validate_batch(_make_plan([_make_shot(prompt_data=prompt)]))
        assert "camera_contradiction" in {w.check for w in issues}

    def test_static_camera_no_motion_words_is_clean(self, checker):
        prompt = {
            "camera_movement": "static",
            "prompt_skeleton": {"camera_line": "Wide angle, centered framing"},
        }
        issues = checker.validate_batch(_make_plan([_make_shot(prompt_data=prompt)]))
        assert not any(w.check == "camera_contradiction" for w in issues)

    def test_motion_camera_with_static_keyword_warns(self, checker):
        prompt = {
            "camera_movement": "pan",
            "prompt_skeleton": {"camera_line": "Camera is locked and static"},
        }
        issues = checker.validate_batch(_make_plan([_make_shot(prompt_data=prompt)]))
        # The reverse mismatch (moving camera, static wording) is expected at "warning" severity.
        flagged = [w for w in issues if w.check == "camera_contradiction"]
        assert any(w.severity == "warning" for w in flagged)

    def test_no_prompt_data_no_crash(self, checker):
        issues = checker.validate_batch(_make_plan([_make_shot(prompt_data={})]))
        assert not any(w.check == "camera_contradiction" for w in issues)


# ── Kling Duration Rounding ──────────────────────────────────────────


class TestKlingDuration:
    """Kling targets of 5 s and 10 s pass, 7 s is flagged; non-Kling models are not checked."""

    @staticmethod
    def _duration_issues(checker, model, seconds):
        """Validate one shot of *model* targeting *seconds* and keep only ``kling_duration`` issues."""
        shot = _make_shot(model=model, routing_data={"target_editorial_duration_s": seconds})
        issues = checker.validate_batch(_make_plan([shot]))
        return [w for w in issues if w.check == "kling_duration"]

    def test_kling_odd_duration_warns(self, checker):
        assert self._duration_issues(checker, "kling-v3", 7)

    def test_kling_5s_is_clean(self, checker):
        assert self._duration_issues(checker, "kling-v3", 5) == []

    def test_kling_10s_is_clean(self, checker):
        assert self._duration_issues(checker, "kling-v3", 10) == []

    def test_non_kling_model_skips_check(self, checker):
        # Same 7 s target that trips Kling, but on a different video model.
        assert self._duration_issues(checker, "veo-3.1", 7) == []


# ── Missing Character Refs ───────────────────────────────────────────


class TestMissingRefs:
    """Each referenced ``char_id`` is expected to have a ``refs/characters/<char_id>`` directory."""

    def test_missing_character_ref_directory(self, checker, tmp_path):
        refs_root = tmp_path / "refs"
        # Create refs/characters/ but deliberately leave out refs/characters/jinx/.
        (refs_root / "characters").mkdir(parents=True)
        shot = _make_shot(asset_data={"characters": [{"char_id": "jinx"}]})
        issues = checker.validate_batch(_make_plan([shot]), refs_dir=refs_root)
        missing = [w for w in issues if w.check == "missing_ref"]
        assert any("jinx" in w.message for w in missing)

    def test_existing_character_ref_is_clean(self, checker, tmp_path):
        refs_root = tmp_path / "refs"
        (refs_root / "characters" / "jinx").mkdir(parents=True)
        shot = _make_shot(asset_data={"characters": [{"char_id": "jinx"}]})
        issues = checker.validate_batch(_make_plan([shot]), refs_dir=refs_root)
        assert all(w.check != "missing_ref" for w in issues)

    def test_empty_characters_no_crash(self, checker, tmp_path):
        refs_root = tmp_path / "refs"
        refs_root.mkdir()
        shot = _make_shot(asset_data={"characters": []})
        issues = checker.validate_batch(_make_plan([shot]), refs_dir=refs_root)
        assert all(w.check != "missing_ref" for w in issues)


# ── Character Count vs Spatial Slots ─────────────────────────────────


class TestCharacterSpatialMismatch:
    """Compare the number of characters against the number of ``character_relationships`` slots."""

    def test_mismatch_warns(self, checker):
        two_characters = {"characters": [{"char_id": "jinx"}, {"char_id": "ava"}]}
        one_slot = {"character_relationships": [{"position": "left"}]}
        shot = _make_shot(asset_data=two_characters, spatial_data=one_slot)
        checks = {w.check for w in checker.validate_batch(_make_plan([shot]))}
        assert "spatial_mismatch" in checks

    def test_matching_counts_clean(self, checker):
        shot = _make_shot(
            asset_data={"characters": [{"char_id": "jinx"}]},
            spatial_data={"character_relationships": [{"position": "center"}]},
        )
        checks = {w.check for w in checker.validate_batch(_make_plan([shot]))}
        assert "spatial_mismatch" not in checks

    def test_no_spatial_data_skips(self, checker):
        # An empty spatial_data dict should skip the comparison entirely.
        shot = _make_shot(asset_data={"characters": [{"char_id": "jinx"}]}, spatial_data={})
        checks = {w.check for w in checker.validate_batch(_make_plan([shot]))}
        assert "spatial_mismatch" not in checks


# ── Prompt Length vs Model Limits ────────────────────────────────────


class TestPromptLength:
    """Total prompt-skeleton size is checked against the model's character limit."""

    def test_exceeding_limit_warns(self, checker):
        # One 4100-char field pushes the skeleton past the 4000-char limit this test assumes.
        skeleton = {"scene_description": "x" * 4100}
        shot = _make_shot(prompt_data={"prompt_skeleton": skeleton})
        checks = {w.check for w in checker.validate_batch(_make_plan([shot]))}
        assert "prompt_length" in checks

    def test_under_limit_is_clean(self, checker):
        skeleton = {"scene_description": "A short prompt."}
        shot = _make_shot(prompt_data={"prompt_skeleton": skeleton})
        checks = {w.check for w in checker.validate_batch(_make_plan([shot]))}
        assert "prompt_length" not in checks

    def test_string_skeleton_handled(self, checker):
        # prompt_skeleton may also be a plain string rather than a dict of fields.
        shot = _make_shot(prompt_data={"prompt_skeleton": "x" * 5000})
        checks = {w.check for w in checker.validate_batch(_make_plan([shot]))}
        assert "prompt_length" in checks


# ── Required Fields per Pipeline ─────────────────────────────────────


class TestRequiredFields:
    """Empty required sections for a pipeline are reported as ``missing_field`` issues."""

    @staticmethod
    def _field_messages(checker, shot):
        """Return the messages of every ``missing_field`` issue raised for a single shot."""
        issues = checker.validate_batch(_make_plan([shot]))
        return [w.message for w in issues if w.check == "missing_field"]

    def test_i2v_missing_prompt_data_errors(self, checker):
        # prompt_data is present but an empty dict (falsy), so it should count as missing.
        shot = _make_shot(pipeline="i2v", prompt_data={}, asset_data={"characters": []})
        assert any("prompt_data" in msg for msg in self._field_messages(checker, shot))

    def test_i2v_missing_asset_data_errors(self, checker):
        shot = _make_shot(pipeline="i2v", prompt_data={"key": "value"}, asset_data={})
        assert any("asset_data" in msg for msg in self._field_messages(checker, shot))

    def test_still_missing_prompt_data_errors(self, checker):
        shot = _make_shot(pipeline="still", prompt_data={})
        assert any("prompt_data" in msg for msg in self._field_messages(checker, shot))

    def test_all_required_fields_present_clean(self, checker):
        shot = _make_shot(
            pipeline="i2v",
            prompt_data={"key": "value"},
            asset_data={"characters": [{"char_id": "jinx"}]},
        )
        assert self._field_messages(checker, shot) == []


# ── Cost Estimation ──────────────────────────────────────────────────


class TestCostEstimation:
    """Cost estimates derived from the ``_TEST_PROFILES`` pricing.

    Costs are floats produced by rate arithmetic (e.g. cost-per-second times
    duration), so non-zero amounts are compared with ``pytest.approx`` instead
    of exact ``==``. Exact float equality would make the tests depend on the
    implementation's order of operations.
    """

    def test_still_pipeline_cost(self, checker):
        shot = _make_shot(pipeline="still")
        est = checker.estimate_cost(_make_plan([shot]))
        # previs (Flash 3.1 = 0.039) + keyframe (Seedream v4.5 = 0.04) + QC (2 * 0.039)
        expected_previs = 0.039
        expected_keyframe = 0.04
        expected_qc = 0.039 * 2
        expected_total = expected_previs + expected_keyframe + expected_qc
        # abs-only tolerance matches the previous hand-rolled `abs(...) < 1e-6` check.
        assert est.total == pytest.approx(expected_total, abs=1e-6)
        assert est.video == 0.0

    def test_i2v_pipeline_includes_video_cost(self, checker):
        shot = _make_shot(
            pipeline="i2v",
            model="kling-v3",
            routing_data={"target_editorial_duration_s": 5},
        )
        est = checker.estimate_cost(_make_plan([shot]))
        # Video: 0.10 $/s * 5 s = 0.50
        assert est.video == pytest.approx(0.50)
        assert est.total > 0.50  # includes previs + keyframe + qc

    def test_per_shot_breakdown(self, checker):
        shots = [
            _make_shot(shot_id="s1", pipeline="still"),
            _make_shot(shot_id="s2", pipeline="i2v", model="kling-v3",
                       routing_data={"target_editorial_duration_s": 5}),
        ]
        est = checker.estimate_cost(_make_plan(shots))
        assert len(est.per_shot) == 2
        assert est.per_shot[0]["shot_id"] == "s1"
        assert est.per_shot[0]["video"] == 0
        assert est.per_shot[1]["video"] == pytest.approx(0.50)

    def test_unknown_model_zero_video_cost(self, checker):
        """Video cost for an unknown model is 0 (no profile entry)."""
        shot = _make_shot(
            pipeline="i2v",
            model="unknown-model-9000",
            routing_data={"target_editorial_duration_s": 10},
        )
        est = checker.estimate_cost(_make_plan([shot]))
        # Video cost should be 0 because the model has no profile
        assert est.video == 0.0
        # But previs/keyframe/QC still calculated from known models
        assert est.previs > 0
        assert est.qc > 0


# ── Empty Plan ───────────────────────────────────────────────────────


class TestEmptyPlan:
    """Plans with no shots yield no warnings and a zero cost estimate."""

    def test_validate_empty_plan(self, checker):
        assert checker.validate_batch(_make_plan([])) == []

    def test_validate_missing_shots_key(self, checker):
        # A plan dict without a "shots" key is expected to behave like an empty plan.
        result = checker.validate_batch({})
        assert result == []

    def test_estimate_cost_empty_plan(self, checker):
        empty = checker.estimate_cost(_make_plan([]))
        assert empty.total == 0.0
        assert empty.per_shot == []


# ── Multiple Warnings on Single Shot ─────────────────────────────────


class TestMultipleWarnings:
    """Independent checks should all fire on one shot that breaks several rules at once."""

    def test_single_shot_multiple_issues(self, checker, tmp_path):
        refs_root = tmp_path / "refs"
        (refs_root / "characters").mkdir(parents=True)  # no per-character ref dirs
        # Deliberately broken shot: static camera described as panning, two
        # characters without ref dirs, only one spatial slot, and a 7 s Kling target.
        problem = _make_shot(
            shot_id="problem_shot",
            pipeline="i2v",
            model="kling-v3",
            prompt_data={
                "camera_movement": "static",
                "prompt_skeleton": {"camera_line": "Slow pan left to right"},
            },
            asset_data={"characters": [{"char_id": "jinx"}, {"char_id": "ava"}]},
            spatial_data={"character_relationships": [{"position": "left"}]},
            routing_data={"target_editorial_duration_s": 7},
        )
        issues = checker.validate_batch(_make_plan([problem]), refs_dir=refs_root)
        expected = {"camera_contradiction", "missing_ref", "spatial_mismatch", "kling_duration"}
        assert expected <= {w.check for w in issues}
        # At least one issue per expected check (missing_ref may appear once per character).
        assert len(issues) >= len(expected)


# ── I2V Keyframe Check ──────────────────────────────────────────────


class TestI2VKeyframe:
    """An i2v shot requiring a match cut but lacking a keyframe reference should be flagged."""

    @staticmethod
    def _i2v_shot(requires_match_cut):
        """Build an i2v shot with non-empty prompt/asset data; only the match-cut flag varies."""
        return _make_shot(
            pipeline="i2v",
            prompt_data={"key": "val"},
            asset_data={"characters": [{"char_id": "jinx"}]},
            routing_data={"narrative_requires_match_cut": requires_match_cut},
        )

    def test_i2v_match_cut_without_keyframe_warns(self, checker):
        # The shot carries no keyframe_ref.
        issues = checker.validate_batch(_make_plan([self._i2v_shot(True)]))
        assert "i2v_no_keyframe" in {w.check for w in issues}

    def test_i2v_no_match_cut_clean(self, checker):
        issues = checker.validate_batch(_make_plan([self._i2v_shot(False)]))
        assert "i2v_no_keyframe" not in {w.check for w in issues}
