"""
validation.py — 4-Gate validation framework for Starsend visual production.

Gate 0: Character ref set internal consistency ($0.001/check)
Gate 1: Mechanical artifact detection — image or video ($0.001/check)
Gate 2: Semantic keyframe check — identity/wardrobe/composition ($0.00014/check)
Gate 3: Video identity drift spot-check at 25%, 50%, 75% ($0.003/clip)

All gates use Flash 3.1 for cost efficiency. Total QC cost: ~$0.34/episode.

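Gates 0 and 1 (image/video) use binary JSON (pass/fail per check). Gate 2 scores each mismatch by categorical severity (MINOR=1, NOTICEABLE=3, CRITICAL=5, summed into a total), NOT 1-10 scoring.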

Retry logic:
  - Transient API errors (HTTP 429, 500, 503): up to 3 attempts with exponential backoff (requires tenacity)
  - Gate 1 mechanical fail: Retry generation 2x, then status: failed
  - Gate 2 semantic fail: NO auto-retry → keyframe_rejected → human review
  - Gate 3 video drift: Flag for human review, do not auto-reject

Ported from Recoil's visual_gate.py (Gate 1), visual_qc.py (Gates 0, 3),
and gemini_qc.py (Gate 2 prompt template).
"""

import json
import logging
import os
import sys
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

# Put the project root (the parent of this file's directory) at the front of
# sys.path so the ``recoil.*`` import below resolves even when the caller has
# not already made the project importable.
_PROJECT_ROOT = Path(__file__).parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

# tenacity is optional. When it is missing, _HAS_TENACITY is False and the
# gate API calls are made once, with no retry on transient errors.
try:
    from tenacity import (
        retry,
        stop_after_attempt,
        wait_exponential,
        retry_if_exception_type,
    )

    _HAS_TENACITY = True
except ImportError:
    _HAS_TENACITY = False

logger = logging.getLogger(__name__)

# Must stay below the sys.path insertion above: it relies on the project root
# being importable.
from recoil.core.model_profiles import get_model

# Gate model — Flash 3.1 for all gates (cost efficiency).
# Resolved once at import time from the ("gate_image", "qc") model profile.
GATE_MODEL = get_model("gate_image", "qc")
# Flat cost recorded on ValidationResult.cost for each gate-model call.
# NOTE(review): the per-gate costs quoted in the module docstring are much
# lower than this value — confirm which figures are current.
GATE_COST_PER_CALL = 0.039  # Flash image cost

# Cheaper text-only calls for binary JSON gates.
# NOTE(review): run_gate_1_text calls GATE_MODEL and records its own 0.00005
# cost rather than these two values — confirm where they are meant to be used.
GATE_TEXT_MODEL = get_model("gate_text", "qc")
GATE_TEXT_COST = 0.001


@dataclass
class ValidationResult:
    """Outcome of a single gate check.

    Attributes:
        gate: Gate identifier, e.g. "gate_0", "gate_1_text", "gate_1_image",
            "gate_1_video", "gate_2", "gate_2a", "gate_3".
        passed: Whether the checked asset cleared the gate.
        details: Gate-specific payload (parsed model reply, scores, errors).
        model: Model that ran the check. Defaults to ""; early-exit paths that
            never reach a model call leave it at the default.
        cost: API cost attributed to this check. Defaults to 0.0.
    """

    gate: str
    passed: bool
    details: dict
    model: str = ""
    cost: float = 0.0


class Validator:
    """4-gate validation framework.

    All gates use Flash 3.1 for cost efficiency.
    Lazy-initializes API client on first use.
    """

    def __init__(self, api_key: Optional[str] = None):
        self._api_key = api_key
        self._client = None

    def _get_client(self):
        """Lazy-initialize Gemini client."""
        if self._client is None:
            from google import genai

            key = (
                self._api_key
                or os.environ.get("GEMINI_API_KEY")
                or os.environ.get("GOOGLE_API_KEY")
            )
            if not key:
                raise RuntimeError("GEMINI_API_KEY or GOOGLE_API_KEY not set")
            self._client = genai.Client(api_key=key)
        return self._client

    def _call_flash(
        self,
        prompt: str,
        images: Optional[list] = None,
        response_schema: Optional[dict] = None,
    ) -> str:
        """Call the gate model with optional images followed by a text prompt.

        Thin adapter over ``_call_flash_parts``. Each entry of ``images`` is
        wrapped as an ``image/png`` Part and placed before ``prompt``. The
        request config and transient-error retry policy therefore live in
        one place instead of being duplicated here.

        Args:
            prompt: Instruction text, sent after all images.
            images: Raw image bytes; each one is labelled ``image/png``.
            response_schema: Optional JSON schema. When given, the response
                is requested as ``application/json``.

        Returns:
            The model's response text.

        Raises:
            Client exceptions propagate to the caller (fail-closed).
        """
        from google.genai import types as genai_types

        # Pre-built Part objects are passed through by _call_flash_parts
        # untouched. Wrapping here keeps every image labelled image/png
        # whatever bytes-like type the caller supplied.
        parts = [
            genai_types.Part.from_bytes(data=img_data, mime_type="image/png")
            for img_data in images or []
        ]
        parts.append(prompt)
        return self._call_flash_parts(parts, response_schema=response_schema)

    def _call_flash_parts(
        self,
        parts: list,
        response_schema: Optional[dict] = None,
    ) -> str:
        """Call Flash 3.1 with pre-built interleaved content parts.

        Unlike _call_flash, this accepts a pre-composed list of text strings
        and image bytes, preserving the interleaving order for proper
        attention anchoring.
        """
        from google.genai import types as genai_types

        client = self._get_client()

        config = genai_types.GenerateContentConfig(
            temperature=0.1,
            max_output_tokens=4096,
        )
        if response_schema:
            config.response_mime_type = "application/json"
            config.response_json_schema = response_schema

        # Convert parts to genai Part objects
        contents = []
        for part in parts:
            if isinstance(part, str):
                contents.append(part)
            elif isinstance(part, bytes):
                contents.append(
                    genai_types.Part.from_bytes(data=part, mime_type="image/png")
                )
            elif isinstance(part, Path):
                if part.exists():
                    contents.append(
                        genai_types.Part.from_bytes(
                            data=part.read_bytes(), mime_type="image/png"
                        )
                    )
                else:
                    logger.warning("_call_flash_parts: skipping missing file %s", part)
            else:
                contents.append(part)  # Already a Part object

        def _do_call():
            response = client.models.generate_content(
                model=GATE_MODEL,
                contents=contents,
                config=config,
            )
            return response.text if hasattr(response, "text") else str(response)

        if _HAS_TENACITY:
            try:
                from google.api_core.exceptions import (
                    ResourceExhausted,
                    InternalServerError,
                    ServiceUnavailable,
                )

                _do_call = retry(
                    stop=stop_after_attempt(3),
                    wait=wait_exponential(multiplier=2, min=2, max=30),
                    retry=retry_if_exception_type(
                        (ResourceExhausted, InternalServerError, ServiceUnavailable)
                    ),
                )(_do_call)
            except ImportError:
                pass

        return _do_call()

    # ── Gate 0: Character Ref Set Consistency ─────────────────────

    def run_gate_0(
        self,
        character: str,
        ref_paths: list[Path],
    ) -> ValidationResult:
        """Gate 0: verify that a character's reference images agree.

        Every existing reference image is sent in one gate-model call. The
        model returns four binary verdicts: identity, wardrobe, artifacts and
        background. Missing paths are skipped. Fewer than two readable refs
        fails immediately, without an API call and at cost 0.0.

        Cost: ~$0.039 (1 Flash call with multiple ref images).

        Args:
            character: Character name interpolated into the prompt.
            ref_paths: Reference image paths (different angles of one character).

        Returns:
            ValidationResult for ``gate_0``. It passes only if the reply is a
            JSON object holding all four checks, each with a boolean ``pass``,
            and all of them True. A malformed reply fails with
            ``flagged_for_review``; an exception fails with ``error``.
            Every path after the ref-count guard records GATE_COST_PER_CALL.
        """
        images = [path.read_bytes() for path in ref_paths if path.exists()]
        if len(images) < 2:
            return ValidationResult(
                gate="gate_0",
                passed=False,
                details={"error": f"Need at least 2 refs, got {len(images)}"},
                cost=0.0,
            )

        prompt = (
            "You are a visual QC inspector for character reference sheets.\n\n"
            f"Character: {character}\n"
            f"These {len(images)} images should all show the SAME character "
            "from different angles.\n\n"
            "Check for:\n"
            "1. IDENTITY: Same person in all images (face, build, ethnicity)\n"
            "2. WARDROBE: Same outfit, accessories, props in all images\n"
            "3. ARTIFACTS: No extra limbs, deformed hands, floating objects\n"
            "4. BACKGROUND: Clean white/neutral background, no contamination\n\n"
            "Respond with JSON. Each check is pass/fail with a reason."
        )

        def _check_schema():
            # One {pass, reason} object per check.
            return {
                "type": "object",
                "properties": {
                    "pass": {"type": "boolean"},
                    "reason": {"type": "string"},
                },
                "required": ["pass", "reason"],
            }

        check_names = [
            "identity_consistent",
            "wardrobe_consistent",
            "no_artifacts",
            "clean_background",
        ]
        schema = {
            "type": "object",
            "properties": {name: _check_schema() for name in check_names},
            "required": check_names,
        }

        def _result(passed, details):
            return ValidationResult(
                gate="gate_0",
                passed=passed,
                details=details,
                model=GATE_MODEL,
                cost=GATE_COST_PER_CALL,
            )

        try:
            payload = json.loads(
                self._call_flash(prompt, images=images, response_schema=schema)
            )
            is_object = isinstance(payload, dict)
            verdicts = (
                [payload.get(name) for name in schema["required"]] if is_object else []
            )
            well_formed = is_object and all(
                isinstance(v, dict) and isinstance(v.get("pass"), bool)
                for v in verdicts
            )
            if not well_formed:
                return _result(
                    False,
                    {
                        "error": "gate response missing required checks",
                        "flagged_for_review": True,
                        "raw": payload,
                    },
                )
            return _result(all(v["pass"] is True for v in verdicts), payload)
        except Exception as exc:
            logger.error("Gate 0 failed: %s", exc)
            return _result(False, {"error": str(exc)})

    # ── Gate 1 (text): Pre-gen Character Accuracy ───────────────────

    def run_gate_1_text(self, authored_prompt: str, shot: dict) -> ValidationResult:
        """Gate 1 (text): Pre-gen character accuracy check on authored prompt.

        Uses text-only Flash call with Extractive CoT to check the authored
        prompt against the shot specification for character mismatches.
        Micro-cost check (~$0.00005) — run BEFORE committing to image generation.

        Returns ValidationResult with extras/missing in details.
        """
        shot_id = shot.get("shot_id", "unknown")
        asset_data = shot.get("asset_data", {})
        shot_chars = asset_data.get("characters", [])
        character_list = (
            ", ".join(
                c.get("char_id", str(c)) if isinstance(c, dict) else str(c)
                for c in shot_chars
            )
            or "none specified"
        )
        location_id = asset_data.get("location_id", shot.get("location_id", "unknown"))
        action_line = shot.get("action", shot.get("description", ""))

        gate_prompt = f"""You are a script supervisor checking a previz prompt for accuracy.

SHOT SPECIFICATION:
- Shot ID: {shot_id}
- Characters in this shot: {character_list}
- Location: {location_id}
- Action: {action_line}

AUTHORED PROMPT TO CHECK:
{authored_prompt}

TASK — Follow these steps exactly:

STEP 1: List every character name or person mentioned in the SHOT SPECIFICATION above.
STEP 2: List every character name or person mentioned in the AUTHORED PROMPT above.
STEP 3: List any characters in the AUTHORED PROMPT that are NOT in the SHOT SPECIFICATION (extras/wrong characters).
STEP 4: List any characters in the SHOT SPECIFICATION that are MISSING from the AUTHORED PROMPT.

VERDICT: If Step 3 or Step 4 found any mismatches, respond FAIL. Otherwise respond PASS.

Respond in this exact format:
SPEC_CHARACTERS: [list]
PROMPT_CHARACTERS: [list]
EXTRAS: [list or "none"]
MISSING: [list or "none"]
VERDICT: PASS or FAIL"""

        try:
            from google.genai import types as genai_types

            client = self._get_client()
            config = genai_types.GenerateContentConfig(
                temperature=0.0,
                responseModalities=["TEXT"],
            )

            response = client.models.generate_content(
                model=GATE_MODEL,
                contents=gate_prompt,
                config=config,
            )

            raw = ""
            if response and response.candidates:
                for candidate in response.candidates:
                    if candidate.content and candidate.content.parts:
                        for part in candidate.content.parts:
                            if hasattr(part, "text") and part.text:
                                raw += part.text

            verdict = "UNKNOWN"
            extras = []
            missing = []

            for line in raw.splitlines():
                line_stripped = line.strip()
                if line_stripped.startswith("VERDICT:"):
                    verdict = line_stripped.split(":", 1)[1].strip().upper()
                elif line_stripped.startswith("EXTRAS:"):
                    val = line_stripped.split(":", 1)[1].strip()
                    if val.lower() not in ("none", "[]", '["none"]', "['none']"):
                        extras = [
                            x.strip().strip("[]\"'")
                            for x in val.split(",")
                            if x.strip()
                        ]
                elif line_stripped.startswith("MISSING:"):
                    val = line_stripped.split(":", 1)[1].strip()
                    if val.lower() not in ("none", "[]", '["none"]', "['none']"):
                        missing = [
                            x.strip().strip("[]\"'")
                            for x in val.split(",")
                            if x.strip()
                        ]

            passed = verdict == "PASS"

            return ValidationResult(
                gate="gate_1_text",
                passed=passed,
                details={
                    "verdict": verdict,
                    "extras": extras,
                    "missing": missing,
                    "raw_response": raw,
                },
                model=GATE_MODEL,
                cost=0.00005,
            )

        except Exception as e:
            logger.error("Gate 1 text critique failed: %s", e)
            return ValidationResult(
                gate="gate_1_text",
                passed=False,
                details={"error": str(e)},
                model=GATE_MODEL,
                cost=0.0,
            )

    # ── Gate 1: Mechanical Artifact Detection ─────────────────────

    def run_gate_1_image(self, image_path: Path) -> ValidationResult:
        """Gate 1 (image): mechanical QC on a generated image.

        One gate-model call returns five binary verdicts:
        - black/empty frame
        - watermark or text overlay
        - severe anatomy errors
        - colour corruption (banding, posterization)
        - blur or very low resolution

        Artistic quality is explicitly out of scope.

        Cost: ~$0.039 (1 Flash call).

        Args:
            image_path: Image file to inspect.

        Returns:
            ValidationResult for ``gate_1_image``. A missing file fails at
            cost 0.0. Otherwise the result passes only if all five checks
            come back with boolean ``pass`` values that are all True. A
            malformed reply fails with ``flagged_for_review``; an exception
            fails with ``error``.
        """
        if not image_path.exists():
            return ValidationResult(
                gate="gate_1_image",
                passed=False,
                details={"error": f"Image not found: {image_path}"},
                cost=0.0,
            )

        image_data = image_path.read_bytes()

        prompt = (
            "You are a mechanical QC inspector for AI-generated images.\n\n"
            "Check this image for MECHANICAL defects only (not artistic quality):\n"
            "1. BLACK_FRAME: Is the image entirely or mostly black/empty?\n"
            "2. WATERMARK: Any visible watermarks, logos, or text overlays?\n"
            "3. ANATOMY: Severe errors — extra limbs, merged body parts, impossible poses?\n"
            "4. COLOR: Extreme banding, posterization, or color corruption?\n"
            "5. RESOLUTION: Is the image blurry or extremely low resolution?\n\n"
            "Respond with JSON. Each check is pass/fail."
        )

        def _check_schema():
            # One {pass, reason} object per check.
            return {
                "type": "object",
                "properties": {
                    "pass": {"type": "boolean"},
                    "reason": {"type": "string"},
                },
                "required": ["pass", "reason"],
            }

        check_names = ["black_frame", "watermark", "anatomy", "color", "resolution"]
        schema = {
            "type": "object",
            "properties": {name: _check_schema() for name in check_names},
            "required": check_names,
        }

        def _result(passed, details):
            return ValidationResult(
                gate="gate_1_image",
                passed=passed,
                details=details,
                model=GATE_MODEL,
                cost=GATE_COST_PER_CALL,
            )

        try:
            payload = json.loads(
                self._call_flash(prompt, images=[image_data], response_schema=schema)
            )
            is_object = isinstance(payload, dict)
            verdicts = (
                [payload.get(name) for name in schema["required"]] if is_object else []
            )
            well_formed = is_object and all(
                isinstance(v, dict) and isinstance(v.get("pass"), bool)
                for v in verdicts
            )
            if not well_formed:
                return _result(
                    False,
                    {
                        "error": "gate response missing required checks",
                        "flagged_for_review": True,
                        "raw": payload,
                    },
                )
            return _result(all(v["pass"] is True for v in verdicts), payload)
        except Exception as exc:
            logger.error("Gate 1 (image) failed: %s", exc)
            return _result(False, {"error": str(exc)})

    def run_gate_1_video(self, video_path: Path) -> ValidationResult:
        """Gate 1 (video): mechanical QC on the LAST frame of a video.

        The last frame is extracted with ``_extract_last_frame`` and written
        to a temporary ``.png`` file. That file is checked with
        ``run_gate_1_image``, and the result is relabelled ``gate_1_video``.
        The temporary file is always removed.

        Cost: ~$0.039 (1 Flash call after frame extraction).

        Args:
            video_path: Video file to inspect.

        Returns:
            A failed ValidationResult at cost 0.0 if the video is missing or
            no frame could be extracted. Otherwise, the image-gate result.
        """
        if not video_path.exists():
            return ValidationResult(
                gate="gate_1_video",
                passed=False,
                details={"error": f"Video not found: {video_path}"},
                cost=0.0,
            )

        last_frame = _extract_last_frame(video_path)
        if last_frame is None:
            return ValidationResult(
                gate="gate_1_video",
                passed=False,
                details={"error": "Could not extract last frame from video"},
                cost=0.0,
            )

        # The cleanup guard covers the write as well as the gate call, so a
        # failed write cannot leave the temp file behind.
        # NOTE(review): the .png suffix assumes _extract_last_frame returns PNG
        # bytes — confirm against its implementation.
        fd, temp_name = tempfile.mkstemp(suffix=".png")
        temp_path = Path(temp_name)
        try:
            with os.fdopen(fd, "wb") as handle:
                handle.write(last_frame)
            result = self.run_gate_1_image(temp_path)
            result.gate = "gate_1_video"
            return result
        finally:
            temp_path.unlink(missing_ok=True)

    # ── Gate 2: Semantic Keyframe Check ───────────────────────────

    # ── Severity mapping (categorical → integer) ──────────────────
    _SEVERITY_MAP = {"MINOR": 1, "NOTICEABLE": 3, "CRITICAL": 5}

    # ── Shot type framing distance (for 2B skip logic) ──────────
    _FRAMING_ORDER = {
        "EWS": 0,
        "WS": 1,
        "MWS": 2,
        "MS": 3,
        "MCU": 4,
        "CU": 5,
        "ECU": 6,
        "OTS": 3,
    }

    def _score_gate_result(self, raw_json: str, gate_name: str) -> ValidationResult:
        """Parse Flash's categorical JSON response into a scored ValidationResult."""
        result = json.loads(raw_json)
        mismatches = result.get("mismatches", [])

        # Identity override: if same_person is explicitly False, ensure CRITICAL identity mismatch
        if result.get("same_person") is False:
            has_identity_mismatch = any(
                m.get("category", "").upper() == "IDENTITY" for m in mismatches
            )
            if not has_identity_mismatch:
                mismatches.append(
                    {
                        "category": "IDENTITY",
                        "visual_evidence": "same_person=False: different individual detected",
                        "severity": "CRITICAL",
                    }
                )
                result["mismatches"] = mismatches

        # Map categorical severity to integers
        total_score = 0
        for m in mismatches:
            sev_str = m.get("severity", "MINOR")
            m["severity_int"] = self._SEVERITY_MAP.get(sev_str, 1)
            total_score += m["severity_int"]
        result["total_score"] = total_score

        if total_score <= 1:
            passed = True
        elif total_score <= 3:
            passed = False
        else:
            passed = False

        result["_action"] = (
            "accept"
            if passed
            else "regenerate"
            if total_score <= 3
            else "flag_for_review"
        )

        if mismatches:
            for m in mismatches:
                logger.info(
                    "%s mismatch: [%s] %s (%s → %d)",
                    gate_name,
                    m.get("category", "?"),
                    m.get("visual_evidence", m.get("description", "?")),
                    m.get("severity", "?"),
                    m.get("severity_int", 0),
                )
            logger.info(
                "%s total score: %d → %s", gate_name, total_score, result["_action"]
            )

        return ValidationResult(
            gate=gate_name,
            passed=passed,
            details=result,
            model=GATE_MODEL,
            cost=GATE_COST_PER_CALL,
        )

    def run_gate_2(
        self,
        keyframe_path: Path,
        ref_paths: list[Path],
        prompt_skeleton: Optional[dict] = None,
        wardrobe_phase_id: Optional[str] = None,
        wardrobe_description: Optional[str] = None,
        continuity_ref_path: Optional[Path] = None,
        preceding_shot_path: Optional[Path] = None,
        preceding_shot_type: Optional[str] = None,
        current_shot_type: Optional[str] = None,
    ) -> ValidationResult:
        """Gate 2: Progressive 2-stage visual QC.

        Stage 2A (Character): Identity, hair, wardrobe, accessories.
        Stage 2B (Environment): Background, lighting, spatial (only if 2A passes).

        Inputs per stage:
          2A: Casting refs (2-3) + text wardrobe spec + last approved frame + keyframe
          2B: Preceding shot + keyframe (skipped if shot types too different)

        Scoring: MINOR=1, NOTICEABLE=3, CRITICAL=5
          - total <= 1: PASS
          - total 2-3: FAIL + retriable (auto-regenerate)
          - total > 3: FAIL + non-retriable (flag for human review)

        Cost: $0.039 (2A only) to $0.078 (2A + 2B) per frame.
        """
        # ── Stage 2A: Character Continuity ───────────────────────
        result_2a = self.run_gate_2a_character(
            keyframe_path=keyframe_path,
            ref_paths=ref_paths,
            wardrobe_description=wardrobe_description,
            wardrobe_phase_id=wardrobe_phase_id,
            continuity_ref_path=continuity_ref_path,
            prompt_skeleton=prompt_skeleton,
            shot_type=current_shot_type,
        )

        if not result_2a.passed:
            return result_2a  # Fail fast — skip 2B

        # ── Stage 2B: Environment Continuity (conditional) ───────
        if preceding_shot_path and preceding_shot_path.exists():
            # Check framing compatibility
            skip_2b = False
            if current_shot_type and preceding_shot_type:
                cur_dist = self._FRAMING_ORDER.get(current_shot_type.upper(), 3)
                prev_dist = self._FRAMING_ORDER.get(preceding_shot_type.upper(), 3)
                if abs(cur_dist - prev_dist) > 2:
                    logger.info(
                        "Gate 2B skipped: framing too different (%s → %s)",
                        preceding_shot_type,
                        current_shot_type,
                    )
                    skip_2b = True

            if not skip_2b:
                result_2b = self.run_gate_2b_environment(
                    keyframe_path=keyframe_path,
                    preceding_shot_path=preceding_shot_path,
                    current_shot_type=current_shot_type or "?",
                    preceding_shot_type=preceding_shot_type or "?",
                )

                # Aggregate scores across both stages
                combined_score = result_2a.details.get(
                    "total_score", 0
                ) + result_2b.details.get("total_score", 0)
                combined_cost = result_2a.cost + result_2b.cost

                # Merge all mismatches
                combined_mismatches = result_2a.details.get(
                    "mismatches", []
                ) + result_2b.details.get("mismatches", [])

                # Re-evaluate pass/fail on combined score
                if combined_score <= 1:
                    combined_passed = True
                    action = "accept"
                elif combined_score <= 3:
                    combined_passed = False
                    action = "regenerate"
                else:
                    combined_passed = False
                    action = "flag_for_review"

                return ValidationResult(
                    gate="gate_2",
                    passed=combined_passed,
                    details={
                        "total_score": combined_score,
                        "mismatches": combined_mismatches,
                        "_action": action,
                        "gate_2a": result_2a.details,
                        "gate_2b": result_2b.details,
                    },
                    model=GATE_MODEL,
                    cost=combined_cost,
                )

        return result_2a

    def run_gate_2a_character(
        self,
        keyframe_path: Path,
        ref_paths: list[Path],
        wardrobe_description: Optional[str] = None,
        wardrobe_phase_id: Optional[str] = None,
        continuity_ref_path: Optional[Path] = None,
        prompt_skeleton: Optional[dict] = None,
        shot_type: Optional[str] = None,
    ) -> ValidationResult:
        """Gate 2A: Character continuity — identity, hair, wardrobe, accessories.

        Uses interleaved content parts for proper attention anchoring.
        Forces Chain-of-Thought via visual_observations field in schema.
        Categorical severity (MINOR/NOTICEABLE/CRITICAL) mapped to ints.

        For WS/EWS shots, softens to presence check only (character too small
        for identity/wardrobe comparison).

        Cost: ~$0.039
        """
        if not keyframe_path.exists():
            return ValidationResult(
                gate="gate_2a",
                passed=False,
                details={"error": f"Keyframe not found: {keyframe_path}"},
                cost=0.0,
            )

        # Filter refs by wardrobe phase if specified
        filtered_refs = ref_paths
        if wardrobe_phase_id:
            phase_filtered = [p for p in ref_paths if wardrobe_phase_id in str(p)]
            if phase_filtered:
                filtered_refs = phase_filtered

        # Limit to 3 casting refs (front, side, hero) to stay under 5 images
        casting_refs = [p for p in filtered_refs if p.exists()][:3]

        # Wide shots: character is too small for identity/wardrobe comparison.
        # Only check that a human figure is present.
        st_upper = (shot_type or "").upper()
        is_wide = st_upper in ("WS", "EWS", "MWS")
        if is_wide:
            return self._run_gate_2a_wide_shot(keyframe_path, prompt_skeleton)

        # ── Build interleaved content parts ──────────────────────
        parts = [
            "You are a strict visual continuity inspector for a cinematic microdrama.",
            "=== CASTING REFERENCES (use for facial identity and skin tone) ===",
        ]
        for p in casting_refs:
            parts.append(p.read_bytes())

        # Text wardrobe spec
        if wardrobe_description:
            parts.append(
                f"=== WARDROBE & HAIR SPECIFICATION ===\n"
                f"{wardrobe_description}\n"
                f"CRITICAL: Flag any additions not listed here. "
                f"Flag any key items missing. "
                f"For hair, 'loose' means not tied back — natural drape variation is OK."
            )

        # Continuity ref (character's last approved frame)
        if continuity_ref_path and continuity_ref_path.exists():
            parts.append(
                "=== LAST APPROVED APPEARANCE "
                "(use for current wardrobe state, hair, accessories) ==="
            )
            parts.append(continuity_ref_path.read_bytes())

        # Instructions BEFORE target keyframe — keyframe must be the absolute
        # last token for maximum recency bias attention from Flash
        comp_note = ""
        if prompt_skeleton:
            comp_note = (
                f"\nExpected composition: "
                f"Subject: {prompt_skeleton.get('subject_line', 'N/A')}, "
                f"Action: {prompt_skeleton.get('action_line', 'N/A')}"
            )

        parts.append(
            f"Compare the TARGET KEYFRAME (below) against all references above.{comp_note}\n\n"
            f"STEP 0: Answer same_person — is this the SAME PERSON as the casting references? "
            f"Yes if same person even with wardrobe/hair differences. No ONLY if clearly a different individual.\n\n"
            f"STEP 1: Fill out visual_observations — describe EXACTLY what you see "
            f"in the target keyframe's hairstyle, wardrobe, and accessories BEFORE "
            f"looking for mismatches.\n"
            f"STEP 2: List any mismatches found.\n\n"
            f"FORGIVENESS ZONES (do NOT flag these):\n"
            f"- Natural variation in how loose hair falls or drapes\n"
            f"- Wrinkles or natural folds in clothing\n"
            f"- Changes in lighting or shadows on the face\n"
            f"- Cropping (don't flag 'missing pants' in a close-up)\n"
            f"- Minor color shifts due to environmental lighting\n\n"
            f"IGNORE the white background in casting references — "
            f"the keyframe will have a scene background.\n\n"
            f"Categories: IDENTITY, HAIRSTYLE, WARDROBE, ACCESSORIES, DISTINGUISHING_MARKS\n"
            f"Severity definitions (use EXACTLY these rules):\n"
            f"  MINOR: Slight color shift, tiny detail, acceptable variation\n"
            f"  NOTICEABLE: Wrong hair structure (ponytail when should be loose), "
            f"added/missing accessory (gloves, jewelry), wrong clothing item, "
            f"missing distinguishing mark — these are fixable by regeneration\n"
            f"  CRITICAL: ONLY for completely wrong person (different face/ethnicity), "
            f"severe image corruption, or multiple simultaneous major failures\n\n"
            f"If NO mismatches, return empty mismatches array."
        )

        # Target keyframe LAST for maximum recency bias attention
        parts.append("=== TARGET KEYFRAME TO INSPECT ===")
        parts.append(keyframe_path.read_bytes())

        schema = {
            "type": "object",
            "properties": {
                "same_person": {
                    "type": "boolean",
                    "description": "Is the person in the target keyframe the SAME person as in the casting references? True if same person (even with wardrobe/hair differences), False only if clearly a different individual.",
                },
                "visual_observations": {
                    "type": "object",
                    "properties": {
                        "observed_hairstyle": {"type": "string"},
                        "observed_wardrobe": {"type": "string"},
                        "observed_accessories": {"type": "string"},
                    },
                    "required": [
                        "observed_hairstyle",
                        "observed_wardrobe",
                        "observed_accessories",
                    ],
                },
                "mismatches": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "category": {"type": "string"},
                            "visual_evidence": {"type": "string"},
                            "severity": {
                                "type": "string",
                                "enum": ["MINOR", "NOTICEABLE", "CRITICAL"],
                            },
                        },
                        "required": ["category", "visual_evidence", "severity"],
                    },
                },
            },
            "required": ["same_person", "visual_observations", "mismatches"],
        }

        try:
            raw = self._call_flash_parts(parts, response_schema=schema)
            return self._score_gate_result(raw, "gate_2a")
        except Exception as e:
            logger.error("Gate 2A failed: %s", e)
            return ValidationResult(
                gate="gate_2a",
                passed=False,
                details={"error": str(e)},
                model=GATE_MODEL,
                cost=GATE_COST_PER_CALL,
            )

    def _run_gate_2a_wide_shot(
        self,
        keyframe_path: Path,
        prompt_skeleton: Optional[dict] = None,
    ) -> ValidationResult:
        """Simplified Gate 2A for wide/establishing shots (WS/EWS/MWS).

        At wide scale the character is too small for Flash to compare identity,
        wardrobe or hair, so this only asks whether a human figure is present
        in the frame.

        Args:
            keyframe_path: Keyframe image to inspect (read eagerly; a read error
                propagates to the caller).
            prompt_skeleton: Optional shot skeleton; its ``subject_line`` tells
                the model who is expected in frame.

        Returns:
            ValidationResult for gate ``"gate_2a_wide"``, scored by
            ``_score_gate_result``. Any exception from the model call or the
            scoring yields ``passed=False`` with the error in ``details``.
        """
        # A subject_line key that is present but None/empty would otherwise be
        # rendered literally into the prompt ("expected subject is: None").
        subject = (prompt_skeleton or {}).get("subject_line") or "a character"

        parts = [
            "You are checking a wide establishing shot for a cinematic microdrama.",
            "=== TARGET KEYFRAME ===",
            keyframe_path.read_bytes(),
            (
                f"This is a WIDE SHOT. The expected subject is: {subject}\n\n"
                f"At this scale, do NOT check identity, wardrobe, or hair details.\n"
                f"ONLY check: Is a human figure visible in the frame?\n\n"
                f"If the character is supposed to be present but the frame is empty "
                f"(pure environment), that is NOTICEABLE.\n"
                f"If a human figure is present (even small/silhouetted), PASS."
            ),
        ]

        schema = {
            "type": "object",
            "properties": {
                "visual_observations": {"type": "string"},
                "mismatches": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "category": {"type": "string"},
                            "visual_evidence": {"type": "string"},
                            "severity": {
                                "type": "string",
                                "enum": ["MINOR", "NOTICEABLE", "CRITICAL"],
                            },
                        },
                        "required": ["category", "visual_evidence", "severity"],
                    },
                },
            },
            "required": ["visual_observations", "mismatches"],
        }

        try:
            raw = self._call_flash_parts(parts, response_schema=schema)
            result = self._score_gate_result(raw, "gate_2a_wide")
            logger.info("Gate 2A (wide shot): %s", "PASS" if result.passed else "FAIL")
            return result
        except Exception as e:
            logger.error("Gate 2A (wide) failed: %s", e)
            return ValidationResult(
                gate="gate_2a_wide",
                passed=False,
                details={"error": str(e)},
                model=GATE_MODEL,
                cost=GATE_COST_PER_CALL,
            )

    def run_gate_2b_environment(
        self,
        keyframe_path: Path,
        preceding_shot_path: Path,
        current_shot_type: str = "?",
        preceding_shot_type: str = "?",
    ) -> ValidationResult:
        """Gate 2B: environment continuity between consecutive shots.

        Sends two images to Flash (the preceding shot first, then the target
        keyframe) and asks for BACKGROUND / LIGHTING / SPATIAL mismatches. The
        prompt explicitly forgives framing, depth-of-field, minor lighting and
        character-position differences. Intended to run only after 2A passes
        and when the shot types are compatible.

        If either image file is missing, the gate is skipped: it reports
        ``passed=True`` at zero cost. Any exception from the model call or the
        scoring yields ``passed=False``.

        Cost: ~$0.039
        """
        missing_input = any(
            not path.exists() for path in (keyframe_path, preceding_shot_path)
        )
        if missing_input:
            return ValidationResult(
                gate="gate_2b",
                passed=True,  # Skip gracefully
                details={"skipped": "missing input images"},
                cost=0.0,
            )

        instructions = "\n".join(
            [
                "Compare the TARGET KEYFRAME's environment against the PRECEDING SHOT.",
                f"The preceding shot is a {preceding_shot_type}, the target is a {current_shot_type}.",
                "",
                "STEP 1: Describe the environment in both frames before comparing.",
                "STEP 2: List any environment continuity mismatches.",
                "",
                "FORGIVENESS ZONES (do NOT flag):",
                "- Elements out of frame due to different camera angle or framing",
                "- Depth of field differences (background blur in CU vs sharp in WS)",
                "- Minor lighting intensity changes",
                "- Character position changes (characters move between shots)",
                "",
                "ONLY flag:",
                "- Background elements that contradict (e.g., wall on wrong side)",
                "- Dramatic lighting direction change (key light switching sides)",
                "- Set dressing that appears/disappears between shots",
                "",
                "Categories: BACKGROUND, LIGHTING, SPATIAL",
                "Severity: MINOR, NOTICEABLE, CRITICAL",
                "",
                "If no mismatches, return empty mismatches array.",
            ]
        )

        # Image bytes are read eagerly; read errors propagate to the caller.
        preceding_bytes = preceding_shot_path.read_bytes()
        target_bytes = keyframe_path.read_bytes()
        parts = [
            "You are an environment continuity inspector for a cinematic microdrama.",
            f"=== PRECEDING SHOT ({preceding_shot_type}) ===",
            preceding_bytes,
            f"=== TARGET KEYFRAME ({current_shot_type}) ===",
            target_bytes,
            instructions,
        ]

        mismatch_item = {
            "type": "object",
            "properties": {
                "category": {"type": "string"},
                "visual_evidence": {"type": "string"},
                "severity": {
                    "type": "string",
                    "enum": ["MINOR", "NOTICEABLE", "CRITICAL"],
                },
            },
            "required": ["category", "visual_evidence", "severity"],
        }
        schema = {
            "type": "object",
            "properties": {
                "visual_observations": {"type": "string"},
                "mismatches": {"type": "array", "items": mismatch_item},
            },
            "required": ["visual_observations", "mismatches"],
        }

        try:
            return self._score_gate_result(
                self._call_flash_parts(parts, response_schema=schema), "gate_2b"
            )
        except Exception as err:
            logger.error("Gate 2B failed: %s", err)
            return ValidationResult(
                gate="gate_2b",
                passed=False,
                details={"error": str(err)},
                model=GATE_MODEL,
                cost=GATE_COST_PER_CALL,
            )

    # ── Gate 3: Video Identity Drift Spot-Check ───────────────────

    def run_gate_3(
        self,
        video_path: Path,
        ref_paths: list[Path],
        sample_pcts: Optional[list] = None,
        shot_metadata: Optional[dict] = None,
    ) -> ValidationResult:
        """Gate 3: Video identity drift detection with progressive sampling.

        Progressive strategy:
        1. Skip entirely for ENV-only shots or shots with no visible characters.
        2. Extract frame at 50% and check. If passes → done ($0.039).
        3. If 50% flags → extract 25% and 75% for confirmation ($0.078 extra).

        Motion blur handling: prompt includes "ONLY flag identity drift if
        face is clearly visible." — do not penalize motion-blurred frames.

        Gate 3 video drift → Flag for human review, do NOT auto-reject. Drift
        never fails the gate. It is surfaced via ``details["flagged_for_review"]``
        and ``details["deferred"]``. The gate only fails (``passed=False``) when
        the video is missing or the 50% frame cannot be extracted; the latter
        is also flagged for review.

        After a 50% flag, the shot is flagged for review when drift is
        confirmed (at least one more flagged frame). It is also flagged when
        any check could not be completed (model/API error or frame-extraction
        failure), because the 50% flag would then stand unconfirmed.

        Args:
            video_path: Generated clip to sample.
            ref_paths: Character reference images; missing files are ignored.
            sample_pcts: Currently unused. Sampling points are fixed at 50%,
                then 25%/75%. Kept for call-site compatibility.
            shot_metadata: Optional shot dict. ``routing_data.is_env_only``,
                ``routing_data.num_characters`` and ``asset_data.characters``
                drive the skip decision; missing/None sections count as empty.

        Cost: $0.039 (happy path) to $0.117 (flagged, full check)
        """
        # Conditional skip: ENV-only or no-character shots
        if shot_metadata:
            # `or {}` / `or []`: tolerate sections that are present but None.
            routing = shot_metadata.get("routing_data") or {}
            assets = shot_metadata.get("asset_data") or {}
            is_env = routing.get("is_env_only", False)
            num_chars = routing.get("num_characters", 0)
            asset_chars = assets.get("characters") or []
            if is_env or (num_chars == 0 and len(asset_chars) == 0):
                logger.info("Gate 3 skipped — ENV-only or no characters")
                return ValidationResult(
                    gate="gate_3",
                    passed=True,
                    details={
                        "skipped": True,
                        "reason": "ENV-only or no visible characters",
                    },
                    cost=0.0,
                )

        if not video_path.exists():
            return ValidationResult(
                gate="gate_3",
                passed=False,
                details={"error": f"Video not found: {video_path}"},
                cost=0.0,
            )

        # Load ref images
        ref_images = [p.read_bytes() for p in ref_paths if p.exists()]

        if not ref_images:
            return ValidationResult(
                gate="gate_3",
                passed=True,
                details={"skipped": True, "reason": "No character refs available"},
                cost=0.0,
            )

        total_cost = 0.0
        frame_results = {}

        # Step 1: Spot check at 50%
        spot_frames = _extract_frames_at_pcts(video_path, [0.50])
        spot_frame = spot_frames[0] if spot_frames else None

        if spot_frame is None:
            return ValidationResult(
                gate="gate_3",
                passed=False,
                details={
                    "error": "Could not extract frame at 50%",
                    "deferred": True,
                    "flagged_for_review": True,
                },
                cost=0.0,
            )

        spot_result = self._check_identity_drift(ref_images, spot_frame, 50)
        total_cost += GATE_COST_PER_CALL
        frame_results["50%"] = spot_result

        if spot_result.get("pass", True):
            # 50% passed — assume video is stable, done
            had_api_error = spot_result.get("api_error", False)
            return ValidationResult(
                gate="gate_3",
                passed=True,
                details={
                    "frames": frame_results,
                    "strategy": "progressive_spot_check",
                    "deferred": had_api_error,
                    "flagged_for_review": had_api_error,
                },
                model=GATE_MODEL,
                cost=total_cost,
            )

        # Step 2: 50% flagged — expand to 25% and 75% for confirmation
        logger.info("Gate 3: 50%% flagged — expanding to 25%%/75%% confirmation")
        confirm_pcts = [0.25, 0.75]
        expand_frames = _extract_frames_at_pcts(video_path, confirm_pcts)

        drift_count = 1  # Already have one flag from 50%
        # True when any check could not be completed. The 50% flag is then
        # unconfirmed and must go to a human rather than be silently dropped.
        any_unverified = bool(spot_result.get("api_error", False))
        for pct, frame_data in zip(confirm_pcts, expand_frames):
            label = f"{int(pct * 100)}%"
            if frame_data is None:
                frame_results[label] = {
                    "pass": True,
                    "reason": "Frame extraction failed — skipped",
                    "extraction_failed": True,
                }
                any_unverified = True
                continue

            result = self._check_identity_drift(ref_images, frame_data, int(pct * 100))
            total_cost += GATE_COST_PER_CALL
            frame_results[label] = result
            if result.get("api_error"):
                any_unverified = True
            if not result.get("pass", True):
                drift_count += 1

        # DEFERRED mode: drift detected → passed=True but flagged for human review.
        # Pipeline continues generating; deferred shots block final export.
        drift_detected = drift_count >= 2 or any_unverified

        return ValidationResult(
            gate="gate_3",
            passed=True,  # Always pass — Gate 3 uses DEFERRED, not auto-reject
            details={
                "frames": frame_results,
                "drift_count": drift_count,
                "strategy": "progressive_expanded",
                "flagged_for_review": drift_detected,
                "deferred": drift_detected,
            },
            model=GATE_MODEL,
            cost=total_cost,
        )

    def _check_identity_drift(
        self, ref_images: list[bytes], frame_data: bytes, pct: int
    ) -> dict:
        """Check a single video frame for identity drift against refs.

        Args:
            ref_images: Raw bytes of the character reference images, sent first.
            frame_data: Raw bytes of the video frame under test, sent last.
            pct: Position of the frame in the clip (0-100), used in the prompt
                and in log messages.

        Returns:
            The model's verdict: a dict with a boolean ``pass`` and a ``reason``
            string. If the call fails, or the response is not a JSON object
            with a boolean ``pass``, it returns a fail-open
            ``{"pass": True, "reason": ..., "api_error": True}`` instead.
            ``run_gate_3`` treats ``api_error`` as a review flag.
        """
        images = ref_images + [frame_data]

        prompt = (
            f"You are checking a video frame for CHARACTER IDENTITY DRIFT.\n\n"
            f"The FIRST {len(ref_images)} images are CHARACTER REFERENCES.\n"
            f"The LAST image is a FRAME from the generated video at {pct}%.\n\n"
            f"IMPORTANT: ONLY flag identity drift if the face is CLEARLY VISIBLE.\n"
            f"Motion blur, back-of-head shots, and silhouettes should PASS.\n\n"
            f"Check: Is this the same character as the references?\n"
            f"Respond with JSON: pass (boolean) and reason (string)."
        )

        schema = {
            "type": "object",
            "properties": {
                "pass": {"type": "boolean"},
                "reason": {"type": "string"},
            },
            "required": ["pass", "reason"],
        }

        try:
            raw = self._call_flash(prompt, images=images, response_schema=schema)
            parsed = json.loads(raw)
        except Exception as e:
            logger.error("Gate 3 drift check at %d%% failed: %s", pct, e)
            return {"pass": True, "reason": f"Check failed: {e}", "api_error": True}

        # Schema adherence is not guaranteed. A JSON null/list/string would
        # crash the caller's .get(), and an object with no boolean "pass"
        # would be read as a silent pass. Treat both like a failed call.
        if not isinstance(parsed, dict) or not isinstance(parsed.get("pass"), bool):
            logger.error(
                "Gate 3 drift check at %d%% returned malformed response: %.200r",
                pct,
                parsed,
            )
            return {
                "pass": True,
                "reason": "Check failed: malformed response",
                "api_error": True,
            }
        return parsed


# ── Frame Extraction Helpers ─────────────────────────────────────────


def _extract_last_frame(video_path: Path) -> Optional[bytes]:
    """Extract the last frame from a video file as PNG bytes.

    Runs ffmpeg with ``-sseof -1`` (start one second before end of file) and
    ``-update 1`` so the single output image is overwritten with the final
    decoded frame.

    Args:
        video_path: Video file to read.

    Returns:
        PNG bytes of the frame, or ``None`` if ffmpeg fails, times out
        (30 s), or produces no output. Failures raised as exceptions are
        logged; a non-zero ffmpeg exit is not.
    """
    import subprocess

    temp_path: Optional[Path] = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            temp_path = Path(f.name)

        # Use ffmpeg to extract last frame
        result = subprocess.run(
            [
                "ffmpeg",
                "-sseof",
                "-1",
                "-i",
                str(video_path),
                "-frames:v",
                "1",
                "-update",
                "1",
                "-y",
                str(temp_path),
            ],
            capture_output=True,
            timeout=30,
        )

        if result.returncode != 0 or not temp_path.exists():
            return None
        # NamedTemporaryFile pre-creates the file, so existence alone does not
        # prove ffmpeg wrote a frame. Treat empty output as a failure.
        return temp_path.read_bytes() or None
    except Exception as e:
        logger.error("Frame extraction failed: %s", e)
        return None
    finally:
        # Always remove the temp file, including on timeout or other errors.
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)


def _extract_frames_at_pcts(
    video_path: Path,
    pcts: list[float],
) -> list:
    """Extract frames at specific percentage points of a video.

    Probes the clip duration with ffprobe, then grabs one PNG frame at
    ``duration * pct`` seconds for each entry in ``pcts``.

    Args:
        video_path: Video file to sample.
        pcts: Fractions of the clip duration (e.g. 0.25 for 25%).

    Returns:
        A list the same length as ``pcts``. Each item is the PNG bytes of the
        frame, or ``None`` if that particular frame could not be extracted.
        Every item is ``None`` when the duration cannot be determined.
    """
    import subprocess

    try:
        # Get video duration
        probe = subprocess.run(
            [
                "ffprobe",
                "-v",
                "error",
                "-show_entries",
                "format=duration",
                "-of",
                "csv=p=0",
                str(video_path),
            ],
            capture_output=True,
            text=True,
            timeout=10,
        )
        duration = float(probe.stdout.strip())
    except Exception as e:
        logger.error("Frame extraction at pcts failed: %s", e)
        return [None] * len(pcts)

    # `not duration > 0` also rejects NaN.
    if not duration > 0:
        logger.error(
            "Frame extraction at pcts failed: invalid duration %r for %s",
            duration,
            video_path,
        )
        return [None] * len(pcts)

    # Each frame is extracted independently, so one failure or timeout no
    # longer discards frames that were already grabbed successfully.
    return [_extract_frame_at_seconds(video_path, duration * pct) for pct in pcts]


def _extract_frame_at_seconds(video_path: Path, timestamp: float) -> Optional[bytes]:
    """Grab a single PNG frame at ``timestamp`` seconds, or ``None`` on failure."""
    import subprocess

    temp_path: Optional[Path] = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            temp_path = Path(f.name)

        result = subprocess.run(
            [
                "ffmpeg",
                "-ss",
                # Fixed-point form: str(float) can emit scientific notation
                # (e.g. "1e-05"), which is not in ffmpeg's documented time
                # duration syntax.
                f"{timestamp:.6f}",
                "-i",
                str(video_path),
                "-frames:v",
                "1",
                "-y",
                str(temp_path),
            ],
            capture_output=True,
            timeout=30,
        )

        if result.returncode != 0 or not temp_path.exists():
            return None
        # The temp file is pre-created, so an empty file means ffmpeg wrote no
        # frame (e.g. the seek landed past the last frame).
        return temp_path.read_bytes() or None
    except Exception as e:
        logger.error("Frame extraction at %.3fs failed: %s", timestamp, e)
        return None
    finally:
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
