"""
preflight.py — Deterministic PreFlightChecker for batch launch validation.

Zero-cost Python checks before committing API spend.
Catches prompt contradictions, missing references, budget overruns, and
structural issues that would waste generation credits.

ADR-R06: PreFlightChecker.
"""

import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from recoil.core.paths import PIPELINE_ROOT, ProjectPaths
from recoil.core.model_profiles import get_model

logger = logging.getLogger(__name__)

# Path to the model-profiles config file under the pipeline root.
# NOTE(review): nothing in this module reads this constant — _load_profiles()
# delegates to recoil.core.model_profiles.load() instead. Presumably kept for
# external importers; confirm before removing.
_PROFILES_PATH = PIPELINE_ROOT / "config" / "model_profiles.json"


def _load_profiles() -> dict:
    """Return model profiles from the shared ``recoil.core.model_profiles`` cache.

    A ``FileNotFoundError`` raised while loading yields an empty dict rather
    than propagating, keeping the legacy fail-soft behavior relied on by
    tests that build a PreFlightChecker without a config tree on disk.
    """
    try:
        from recoil.core.model_profiles import load as load_cached_profiles

        profiles = load_cached_profiles()
    except FileNotFoundError:
        profiles = {}
    return profiles


@dataclass
class PreFlightWarning:
    """A single pre-flight check result.

    PreFlightChecker.validate_batch returns a list of these; an empty list
    means every check passed.
    """

    # Shot the finding applies to. The lighting-consistency check reports
    # per scene as "scene_<index>"; shots without a shot_id appear as
    # "unknown" or "?" depending on the check.
    shot_id: str
    severity: str  # "error" | "warning"
    # Machine-readable check name.
    check: str  # e.g. "camera_contradiction", "missing_ref"
    # Human-readable explanation of the finding.
    message: str


@dataclass
class CostEstimate:
    """Cost breakdown for a batch.

    Filled in by PreFlightChecker.estimate_cost from model-profile prices.
    NOTE(review): the currency is not stated in this module — presumably
    USD; confirm against model_profiles.json.
    """

    total: float = 0.0  # previs + keyframes + video + qc
    previs: float = 0.0  # sum of per-shot previs image costs
    keyframes: float = 0.0  # sum of per-shot keyframe image costs
    video: float = 0.0  # sum of per-shot video costs (video pipelines only)
    qc: float = 0.0  # sum of per-shot QC gate costs
    # One dict per shot with keys: shot_id, previs, keyframe, video, qc, total.
    per_shot: list[dict] = field(default_factory=list)


# Camera movements that put the frame in motion. A shot whose camera_movement
# is one of these is checked for a camera_line describing a locked-off camera.
_MOTION_CAMERAS = set(
    "pan tilt push_in pull_back tracking crane handheld steadicam dolly".split()
)

# Close-up shot-type codes (medium close-up and tighter, plus inserts).
_CLOSE_SHOT_TYPES = set("MCU CU ECU INSERT".split())

# Shot keys that must be present and non-empty for each pipeline type.
# Stills need only prompt data; the video pipelines also need asset data.
# Each pipeline gets its own list object.
_REQUIRED_FIELDS = {"still": ["prompt_data"]}
_REQUIRED_FIELDS.update(
    (pipeline_name, ["prompt_data", "asset_data"])
    for pipeline_name in ("i2v", "t2v", "multi_shot")
)


class PreFlightChecker:
    """Deterministic pre-flight validation for batch launches.

    All checks are local Python — no API calls, zero cost.

    Plan sections (``prompt_data``, ``asset_data``, ``spatial_data``,
    ``routing_data``, ``lighting``) are read via ``_section``: a section that
    is missing, ``None`` or empty is treated as ``{}``. A shot with
    ``"prompt_data": null`` is therefore reported by ``_check_required_fields``
    instead of crashing an earlier check with ``AttributeError``.
    """

    # Motion stems searched for in a free-text camera_line when the camera is
    # static. Tuples (not sets) so the reported word is deterministic.
    _MOTION_WORDS = (
        "pan",
        "tilt",
        "track",
        "dolly",
        "crane",
        "push",
        "pull",
        "sweep",
        "arc",
    )
    # Locked-off stems searched for when the camera is declared moving.
    _STATIC_WORDS = ("static", "locked", "fixed", "tripod", "stationary")
    # Light directions that do not favour a screen side; ignored when
    # comparing lighting direction across a scene.
    _NEUTRAL_LIGHT_DIRECTIONS = frozenset({"ABOVE", "BELOW", "SELF_ILLUMINATED"})
    # Shot types treated as wider than an ECU/BCU for punch-in detection.
    _WIDER_SHOT_TYPES = frozenset({"WS", "WIDE", "EWS", "LS", "MS", "MCU"})
    # Prompt budget: Gemini models handle ~4000 chars comfortably.
    _MAX_PROMPT_CHARS = 4000
    # QC estimate per shot: Gate 1 + Gate 2 at minimum, at the Flash rate.
    _QC_COST_PER_SHOT = 0.039 * 2
    # Used when a shot has no target_editorial_duration_s.
    _DEFAULT_DURATION_S = 5

    def __init__(self, model_profiles: Optional[dict] = None):
        """Create a checker.

        Args:
            model_profiles: Mapping of model id -> profile dict used for cost
                lookups. When omitted (``None``), profiles come from
                ``_load_profiles()``. An explicitly passed empty dict is kept
                as "no profiles" rather than triggering a load.
        """
        self.profiles = (
            _load_profiles() if model_profiles is None else model_profiles
        )

    def validate_batch(
        self,
        plan: dict,
        refs_dir: Optional[Path] = None,
        project: Optional[str] = None,
    ) -> list[PreFlightWarning]:
        """Run all pre-flight checks on a plan.

        Cross-shot checks (180-degree rule, lighting, punch-ins, cross-episode
        continuity) run first over the full shot list. Per-shot checks then
        run in plan order.

        Args:
            plan: Episode plan dict with shots[].
            refs_dir: Path to refs directory for existence checks.
            project: Project slug to resolve ProjectPaths against. If None,
                falls back to DEFAULT_PROJECT (legacy behavior).

        Returns:
            List of PreFlightWarning objects. Empty = all clear.
        """
        warnings: list[PreFlightWarning] = []
        shots = plan.get("shots") or []

        # Cross-shot spatial checks (run once on full list)
        warnings.extend(self._check_180_rule(shots))
        warnings.extend(self._check_lighting_scene_consistency(shots))
        warnings.extend(self._check_punch_in_refs(shots))
        warnings.extend(self._check_cross_episode_continuity(plan))

        # Per-shot checks
        for shot in shots:
            shot_id = shot.get("shot_id", "unknown")
            warnings.extend(self._check_camera_contradiction(shot_id, shot))
            warnings.extend(self._check_kling_duration(shot_id, shot))
            warnings.extend(
                self._check_missing_refs(shot_id, shot, refs_dir, project)
            )
            warnings.extend(self._check_character_spatial_mismatch(shot_id, shot))
            warnings.extend(self._check_prompt_length(shot_id, shot))
            warnings.extend(self._check_required_fields(shot_id, shot))
            warnings.extend(self._check_i2v_needs_keyframe(shot_id, shot))

        return warnings

    def estimate_cost(self, plan: dict) -> CostEstimate:
        """Estimate total API cost for a plan.

        Uses model_profiles.json costs. Returns breakdown by stage. Every
        shot is charged one previs image, one keyframe image and a flat QC
        fee. i2v/t2v/multi_shot shots also pay the per-second video cost of
        their model for the target editorial duration.
        """
        est = CostEstimate()
        shots = plan.get("shots") or []
        if not shots:
            return est

        # The previs/keyframe models do not depend on the shot, so resolve
        # their per-image costs once.
        previs_cost = self._get_image_cost(get_model("exploration", "image"))  # Flash 3.1
        keyframe_model = get_model("production", "image")  # NBP
        keyframe_cost = self._get_image_cost(keyframe_model)
        qc_cost = self._QC_COST_PER_SHOT

        for shot in shots:
            pipeline = shot.get("pipeline", "still")
            # NOTE(review): shots without an explicit "model" fall back to the
            # production *image* model, so a video shot's per-second cost is
            # looked up on an image profile — confirm this default is intended.
            model = shot.get("model", keyframe_model)

            shot_cost = {
                "shot_id": shot.get("shot_id", "unknown"),
                "previs": previs_cost,
                "keyframe": keyframe_cost,
                "video": 0,
                "qc": qc_cost,
            }
            if pipeline in ("i2v", "t2v", "multi_shot"):
                shot_cost["video"] = self._get_video_cost(
                    model, self._target_duration(shot)
                )
            shot_cost["total"] = (
                shot_cost["previs"]
                + shot_cost["keyframe"]
                + shot_cost["video"]
                + shot_cost["qc"]
            )

            est.previs += previs_cost
            est.keyframes += keyframe_cost
            est.video += shot_cost["video"]
            est.qc += qc_cost
            est.per_shot.append(shot_cost)

        est.total = est.previs + est.keyframes + est.video + est.qc
        return est

    # ── Shared Helpers ────────────────────────────────────────────

    @staticmethod
    def _section(container: dict, key: str) -> dict:
        """Return ``container[key]``, or ``{}`` when it is missing or falsy.

        Lets callers chain ``.get`` on plan sections that may be absent or
        explicitly ``None``.
        """
        return container.get(key) or {}

    def _target_duration(self, shot: dict):
        """Return the shot's target editorial duration in seconds.

        Falls back to ``_DEFAULT_DURATION_S`` when ``routing_data`` or the
        duration is missing or ``None``.
        """
        duration = self._section(shot, "routing_data").get(
            "target_editorial_duration_s"
        )
        return self._DEFAULT_DURATION_S if duration is None else duration

    @staticmethod
    def _group_by_scene(shots: list[dict]) -> dict[int, list[dict]]:
        """Group shots by ``scene_index`` (default 0), preserving plan order."""
        scenes: dict[int, list[dict]] = {}
        for shot in shots:
            scenes.setdefault(shot.get("scene_index", 0), []).append(shot)
        return scenes

    @staticmethod
    def _first_word_match(text: str, stems: tuple) -> Optional[str]:
        """Return the first of ``stems`` that begins a word in ``text``.

        ``text`` is expected to be lowercase. A stem matches only when no
        letter or digit precedes it. Inflections still hit ("tracking",
        "pushes"), but mid-word occurrences ("expanse", "search") do not.
        Underscore and hyphen act as separators ("push_in", "whip-pan").
        Word-start prefixes can still hit unrelated words (e.g. "panorama").
        Stems are tried in order, so the result is deterministic.
        """
        for stem in stems:
            if re.search(r"(?<![a-z0-9])" + re.escape(stem), text):
                return stem
        return None

    # ── Individual Checks ─────────────────────────────────────────

    def _check_camera_contradiction(
        self, shot_id: str, shot: dict
    ) -> list[PreFlightWarning]:
        """Check for STATIC camera with motion descriptors, and vice versa.

        ``camera_movement`` is compared case-insensitively. ``camera_line``
        words are matched as word-start stems (see ``_first_word_match``).
        """
        prompt_data = self._section(shot, "prompt_data")
        camera_movement = prompt_data.get("camera_movement", "static")
        skeleton = prompt_data.get("prompt_skeleton", {})
        raw_line = skeleton.get("camera_line") if isinstance(skeleton, dict) else None
        camera_line = str(raw_line) if raw_line else ""
        if not camera_line:
            return []

        movement = str(camera_movement).strip().lower()
        lower_camera = camera_line.lower()

        # Static camera but prompt mentions movement: a hard contradiction.
        if movement == "static":
            word = self._first_word_match(lower_camera, self._MOTION_WORDS)
            if word:
                return [
                    PreFlightWarning(
                        shot_id=shot_id,
                        severity="error",
                        check="camera_contradiction",
                        message=f"Camera is STATIC but camera_line contains motion: '{word}' in '{camera_line}'",
                    )
                ]
        # Motion camera but prompt says static/locked/fixed: softer warning.
        elif movement in _MOTION_CAMERAS:
            word = self._first_word_match(lower_camera, self._STATIC_WORDS)
            if word:
                return [
                    PreFlightWarning(
                        shot_id=shot_id,
                        severity="warning",
                        check="camera_contradiction",
                        message=f"Camera is {camera_movement} but camera_line says '{word}'",
                    )
                ]
        return []

    def _check_kling_duration(self, shot_id: str, shot: dict) -> list[PreFlightWarning]:
        """Kling only accepts 5s or 10s — flag odd durations.

        A shot counts as Kling when its model id contains "kling"
        (case-insensitive).
        """
        model = str(shot.get("model") or "")
        if "kling" not in model.lower():
            return []

        duration = self._target_duration(shot)
        if duration in (5, 10):
            return []
        return [
            PreFlightWarning(
                shot_id=shot_id,
                severity="warning",
                check="kling_duration",
                message=f"Kling duration {duration}s will be rounded to {10 if duration > 5 else 5}s",
            )
        ]

    def _check_missing_refs(
        self,
        shot_id: str,
        shot: dict,
        refs_dir: Optional[Path],
        project: Optional[str] = None,
    ) -> list[PreFlightWarning]:
        """Check that character refs exist on disk.

        Under the v3 layout, character refs live at
        ``assets/char/<slug>/`` (resolved via ``ProjectPaths.asset_subject_dir``).
        Legacy callers that still pass a ``refs_dir`` are honored — they
        are assumed to be a v1 ``output/refs/`` root and probed at
        ``<refs_dir>/characters/<char_id>``. If that location is missing,
        ``RECOIL_ROOT/<project or DEFAULT_PROJECT>/refs/characters/<char_id>``
        is tried before an error is reported.
        """
        warnings = []
        characters = self._section(shot, "asset_data").get("characters") or []
        paths = None  # resolved lazily, at most once, only for the v3 layout

        for char in characters:
            raw_id = char.get("char_id") if isinstance(char, dict) else char
            char_id = str(raw_id) if raw_id else ""
            if not char_id:
                continue

            if refs_dir is None:
                # v3 layout: assets/char/<slug>/ via ProjectPaths
                if paths is None:
                    paths = ProjectPaths.for_project(project)
                char_ref_dir = paths.asset_subject_dir("char", char_id)
            else:
                # Legacy explicit refs_dir (treat as v1 output/refs/ root)
                char_ref_dir = refs_dir / "characters" / char_id
            if char_ref_dir.exists():
                continue

            # Also check the Recoil refs path
            from recoil.core.paths import RECOIL_ROOT, DEFAULT_PROJECT

            proj_slug = project or DEFAULT_PROJECT
            if (RECOIL_ROOT / proj_slug / "refs" / "characters" / char_id).exists():
                continue

            warnings.append(
                PreFlightWarning(
                    shot_id=shot_id,
                    severity="error",
                    check="missing_ref",
                    message=f"No reference images found for character '{char_id}'",
                )
            )

        return warnings

    def _check_character_spatial_mismatch(
        self, shot_id: str, shot: dict
    ) -> list[PreFlightWarning]:
        """Check character count vs spatial slot count.

        Only fires when both counts are non-zero and differ.
        """
        char_count = len(self._section(shot, "asset_data").get("characters") or [])
        spatial_slots = len(
            self._section(shot, "spatial_data").get("character_relationships") or []
        )

        if char_count and spatial_slots and char_count != spatial_slots:
            return [
                PreFlightWarning(
                    shot_id=shot_id,
                    severity="warning",
                    check="spatial_mismatch",
                    message=f"{char_count} characters but {spatial_slots} spatial slots",
                )
            ]
        return []

    def _check_prompt_length(self, shot_id: str, shot: dict) -> list[PreFlightWarning]:
        """Check estimated prompt length against ``_MAX_PROMPT_CHARS``.

        For a dict skeleton, the estimate is the summed ``str()`` length of
        its values; otherwise it is the length of ``str(skeleton)``.
        """
        skeleton = self._section(shot, "prompt_data").get("prompt_skeleton") or {}
        if isinstance(skeleton, dict):
            total_chars = sum(len(str(v)) for v in skeleton.values())
        else:
            total_chars = len(str(skeleton))

        max_chars = self._MAX_PROMPT_CHARS
        if total_chars <= max_chars:
            return []

        model = shot.get("model", "")
        return [
            PreFlightWarning(
                shot_id=shot_id,
                severity="warning",
                check="prompt_length",
                message=f"Estimated prompt {total_chars} chars exceeds {max_chars} char limit for {model}",
            )
        ]

    def _check_required_fields(
        self, shot_id: str, shot: dict
    ) -> list[PreFlightWarning]:
        """Check pipeline-specific required fields are present and non-empty."""
        pipeline = shot.get("pipeline", "still")
        required = _REQUIRED_FIELDS.get(pipeline, [])

        return [
            PreFlightWarning(
                shot_id=shot_id,
                severity="error",
                check="missing_field",
                message=f"Pipeline '{pipeline}' requires '{field_name}' but it's missing",
            )
            for field_name in required
            if not shot.get(field_name)
        ]

    def _check_i2v_needs_keyframe(
        self, shot_id: str, shot: dict
    ) -> list[PreFlightWarning]:
        """I2V match-cut shots need a keyframe reference image."""
        if shot.get("pipeline", "") != "i2v":
            return []

        routing = self._section(shot, "routing_data")
        if routing.get("narrative_requires_match_cut") and not shot.get("keyframe_ref"):
            return [
                PreFlightWarning(
                    shot_id=shot_id,
                    severity="warning",
                    check="i2v_no_keyframe",
                    message="I2V pipeline requires keyframe ref for match cut — will be generated in previs stage",
                )
            ]
        return []

    # ── Spatial Continuity Checks ────────────────────────────────

    def _check_180_rule(self, shots: list[dict]) -> list[PreFlightWarning]:
        """Check for 180-degree rule violations within scenes.

        Rule: If characters A and B are LEFT/RIGHT in shot N, they must
        remain LEFT/RIGHT in shot N+1 UNLESS camera_side flips A<->B.

        Consecutive shots of the same scene are compared. Only characters
        with an explicit ``screen_position`` in both shots are considered.
        Every pair of shared characters is checked, and a pair is flagged
        only when it has a strict left/right order in both shots and that
        order reverses. Pairs that share a column (equal
        ``_position_order`` rank) carry no left/right information and are
        skipped.
        """
        warnings = []
        for scene_shots in self._group_by_scene(shots).values():
            for prev, curr in zip(scene_shots, scene_shots[1:]):
                prev_side = self._section(prev, "spatial_data").get("camera_side", "A")
                curr_side = self._section(curr, "spatial_data").get("camera_side", "A")
                if prev_side != curr_side:
                    # Deliberate line cross: positions are expected to flip.
                    continue

                prev_positions = self._extract_char_positions(prev, default=None)
                curr_positions = self._extract_char_positions(curr, default=None)
                shared = sorted(prev_positions.keys() & curr_positions.keys(), key=str)

                for idx, a in enumerate(shared):
                    for b in shared[idx + 1 :]:
                        before = self._order_sign(prev_positions[a], prev_positions[b])
                        after = self._order_sign(curr_positions[a], curr_positions[b])
                        if before and after and before != after:
                            warnings.append(
                                PreFlightWarning(
                                    shot_id=curr.get("shot_id", "?"),
                                    severity="error",
                                    check="180_rule_violation",
                                    message=(
                                        f"Characters {a} and {b} swap positions between "
                                        f"{prev.get('shot_id')} and {curr.get('shot_id')} "
                                        f"without camera_side change (both Side {curr_side})"
                                    ),
                                )
                            )
        return warnings

    def _check_lighting_scene_consistency(
        self, shots: list[dict]
    ) -> list[PreFlightWarning]:
        """Check that lighting direction is consistent within a scene.

        For each shot, the source at ``dominant_source_index`` is used; the
        index is clamped into range. Neutral directions (above, below,
        self-illuminated) are ignored.
        """
        warnings = []
        for scene_idx, scene_shots in self._group_by_scene(shots).items():
            directions = set()
            for shot in scene_shots:
                lighting = self._section(self._section(shot, "prompt_data"), "lighting")
                sources = lighting.get("sources") or []
                if not sources:
                    continue
                index = lighting.get("dominant_source_index") or 0
                dominant = sources[max(0, min(index, len(sources) - 1))]
                if not isinstance(dominant, dict):
                    continue
                direction = dominant.get("direction", "")
                if direction and direction not in self._NEUTRAL_LIGHT_DIRECTIONS:
                    directions.add(direction)

            if len(directions) > 1:
                # Sorted so the message is identical across runs.
                ordered = sorted(directions, key=str)
                warnings.append(
                    PreFlightWarning(
                        shot_id=f"scene_{scene_idx}",
                        severity="warning",
                        check="lighting_inconsistency",
                        message=(
                            f"Scene {scene_idx} has mixed lighting directions: {ordered}. "
                            f"This may cause spatial continuity breaks across cuts."
                        ),
                    )
                )
        return warnings

    def _check_punch_in_refs(self, shots: list[dict]) -> list[PreFlightWarning]:
        """Warn if ECU/BCU follows a wider shot but has no punch_in_from field.

        Only fires when both shots share at least one identified character.
        """
        warnings = []
        for scene_shots in self._group_by_scene(shots).values():
            for prev, shot in zip(scene_shots, scene_shots[1:]):
                shot_type = self._section(shot, "prompt_data").get("shot_type", "")
                if shot_type not in ("ECU", "BCU"):
                    continue
                if self._section(shot, "spatial_data").get("punch_in_from"):
                    continue

                prev_type = self._section(prev, "prompt_data").get("shot_type", "")
                if prev_type not in self._WIDER_SHOT_TYPES:
                    continue

                if self._char_ids(shot) & self._char_ids(prev):
                    warnings.append(
                        PreFlightWarning(
                            shot_id=shot.get("shot_id", "?"),
                            severity="warning",
                            check="punch_in_no_ref",
                            message=(
                                f"{shot_type} follows {prev_type} ({prev.get('shot_id')}) "
                                f"with shared characters but no punch_in_from field."
                            ),
                        )
                    )
        return warnings

    def _check_cross_episode_continuity(self, plan: dict) -> list[PreFlightWarning]:
        """Check that episode-opening shots have spatial data for continuity.

        Skipped for episode 1, for plans without shots, and when ``episode``
        cannot be read as an integer.
        """
        shots = plan.get("shots") or []
        try:
            episode = int(plan.get("episode", 1))
        except (TypeError, ValueError):
            return []

        if episode <= 1 or not shots:
            return []

        first_shot = shots[0]
        if self._section(first_shot, "spatial_data").get("camera_side"):
            return []

        return [
            PreFlightWarning(
                shot_id=first_shot.get("shot_id", "?"),
                severity="warning",
                check="cross_episode_no_spatial",
                message=(
                    f"First shot of episode {episode} has no camera_side. "
                    f"Spatial continuity with Ep {episode - 1} may break."
                ),
            )
        ]

    def _char_ids(self, shot: dict) -> set:
        """Return the non-empty character ids of a shot (dict or bare form)."""
        ids = set()
        for char in self._section(shot, "asset_data").get("characters") or []:
            char_id = char.get("char_id") if isinstance(char, dict) else char
            if char_id:
                ids.add(char_id)
        return ids

    def _extract_char_positions(
        self, shot: dict, default: Optional[str] = "center"
    ) -> dict[str, str]:
        """Extract {char_id: screen_position} from a shot.

        Only dict-form characters with a non-empty ``char_id`` are included.
        A character whose ``screen_position`` is missing or ``None`` gets
        ``default``; pass ``default=None`` to omit such characters instead.
        """
        positions = {}
        for char in self._section(shot, "asset_data").get("characters") or []:
            if not isinstance(char, dict) or not char.get("char_id"):
                continue
            position = char.get("screen_position")
            if position is None:
                position = default
            if position is None:
                continue
            positions[char["char_id"]] = position
        return positions

    @staticmethod
    def _position_order(pos: str) -> int:
        """Convert position to sortable integer: left=0, center=1, right=2.

        Matching ignores case and surrounding whitespace and treats "_" like
        "-" (so "Center_Left" ranks as "center-left"). Unknown values rank
        as center.
        """
        key = pos.strip().lower().replace("_", "-") if isinstance(pos, str) else pos
        return {
            "left": 0,
            "center-left": 0,
            "center": 1,
            "center-right": 2,
            "right": 2,
        }.get(key, 1)

    @classmethod
    def _order_sign(cls, pos_a: str, pos_b: str) -> int:
        """Return -1 if pos_a is left of pos_b, 1 if right of it, 0 if tied."""
        rank_a = cls._position_order(pos_a)
        rank_b = cls._position_order(pos_b)
        return (rank_a > rank_b) - (rank_a < rank_b)

    # ── Cost Helpers ──────────────────────────────────────────────

    def _get_image_cost(self, model_id: str) -> float:
        """Look up per-image cost from profiles (0.0 when unknown)."""
        profile = self.profiles.get(model_id) or {}
        return profile.get("cost_per_image", 0.0)

    def _get_video_cost(self, model_id: str, duration_s: int) -> float:
        """Look up per-second cost from profiles and scale by duration.

        Prefers ``cost_per_second`` and falls back to
        ``cost_per_second_standard``, then 0.0.
        """
        profile = self.profiles.get(model_id) or {}
        cost_per_sec = profile.get(
            "cost_per_second", profile.get("cost_per_second_standard", 0.0)
        )
        return cost_per_sec * duration_s
