"""
video_frame_critic.py — Video Frame Validator.

CriticLoop subclass that validates generated video by extracting and
checking N frames. Catches extra limbs, style inconsistency, and
missing elements across the video.

max_attempts=1: video can't be auto-fixed, only flagged for regeneration.
"""

import logging
from pathlib import Path
from typing import Any, Optional

from recoil.core.critic import CriticLoop, Dimension, Severity
from recoil.core.vision_check import validate_video_frames

from . import FailureMode

# Module-level logger named after this module.
# NOTE(review): not referenced by any code in this file as shown — confirm it
# is still needed (or use it for diagnostics).
logger = logging.getLogger(__name__)


class VideoFrameCritic(CriticLoop):
    """Validate a generated video by extracting and checking N frames.

    Every sampled frame is sent to ``validate_video_frames`` with the same
    list of yes/no checks. The per-frame verdicts are then aggregated into
    one ``Dimension`` per check.

    Dimensions (aggregated across all sampled frames):
        EXTRA_LIMBS (HARD) — no characters with phantom/duplicate limbs.
        STYLE_CONSISTENT (SOFT) — visual style matches ``expected_style``;
            only checked when ``expected_style`` is non-empty.
        ELEMENT_PERSISTENCE_<NAME> (SOFT, one per element) — each expected
            element is visible.
        FRAME_EXTRACTION (HARD) — returned on its own when no frame results
            came back (e.g. the video could not be decoded).

    Aggregation: a check (hard or soft) fails if ANY sampled frame fails it.
    The failing Dimension's message lists the indices of the failed frames.

    Args:
        character_type: "human", "quadruped", or "vehicle".
        expected_style: Description of expected visual style ("" disables
            the STYLE_CONSISTENT check).
        expected_elements: Elements that should persist across frames.
            Blank entries are ignored. Entries that normalize to the same
            check name (e.g. "red car" and "Red_Car") are checked once.
        num_frames: Number of frames to extract and check (default 5).
        experience_pool_dir: Path for JSONL logging.
        shot_id: Shot identifier for logging.
        intention_context: Optional dict passed unchanged to
            ``validate_video_frames``.
    """

    def __init__(
        self,
        character_type: str = "human",
        expected_style: str = "",
        expected_elements: Optional[list[str]] = None,
        num_frames: int = 5,
        experience_pool_dir: Optional[Path] = None,
        shot_id: str = "",
        intention_context: Optional[dict] = None,
    ):
        super().__init__(
            name="video_frame",
            # Video can't be auto-fixed, only flagged for regeneration, so a
            # second attempt would be pointless.
            max_attempts=1,
            experience_pool_dir=experience_pool_dir,
            shot_id=shot_id,
        )
        self.character_type = character_type
        self.expected_style = expected_style
        # Keep a private copy. A caller mutating its list (or passing a
        # one-shot iterator) must not change the checks between evaluations.
        self.expected_elements = list(expected_elements or [])
        self.num_frames = num_frames
        self.intention_context = intention_context

    @staticmethod
    def _element_check_name(element: str) -> str:
        """Return the ELEMENT_PERSISTENCE check name for a stripped element."""
        return f"ELEMENT_PERSISTENCE_{element.upper().replace(' ', '_')}"

    def _build_checks(self) -> list[dict]:
        """Build the check list for video frame validation.

        Returns:
            List of check dicts with ``name``, ``question``, ``expected``
            ("yes"/"no") and ``severity`` ("hard"/"soft"). Check names are
            unique.
        """
        checks = [{
            "name": "EXTRA_LIMBS",
            "question": (
                f"Does any {self.character_type} character in this frame have "
                f"extra, phantom, or duplicate limbs, fingers, or appendages? "
                f"Answer 'no' if anatomy looks normal, 'yes' if there are extras."
            ),
            "expected": "no",
            "severity": "hard",
        }]

        if self.expected_style:
            checks.append({
                "name": "STYLE_CONSISTENT",
                "question": (
                    f"Does this frame match the expected visual style: "
                    f"'{self.expected_style}'? Answer 'yes' if consistent, "
                    f"'no' if the style is different."
                ),
                "expected": "yes",
                "severity": "soft",
            })

        # ELEMENT_PERSISTENCE — one check per distinct, non-blank element.
        seen_names: set[str] = set()
        for raw_element in self.expected_elements:
            element = raw_element.strip()
            if not element:
                continue  # a blank element would yield a meaningless question
            name = self._element_check_name(element)
            if name in seen_names:
                continue  # same check already requested for an equivalent element
            seen_names.add(name)
            checks.append({
                "name": name,
                "question": (
                    f"Is a {element} visible in this frame? "
                    f"Answer 'yes' or 'no'."
                ),
                "expected": "yes",
                "severity": "soft",
            })

        return checks

    @staticmethod
    def _failure_mode_for(check_name: str) -> FailureMode:
        """Map a failed check name to the FailureMode reported for it."""
        if check_name == "EXTRA_LIMBS":
            return FailureMode.ANATOMY_LIMB_MISCOUNT
        if check_name == "STYLE_CONSISTENT":
            # NOTE(review): style drift is reported as LIGHTING_MISMATCH;
            # confirm no closer FailureMode exists.
            return FailureMode.LIGHTING_MISMATCH
        # ELEMENT_PERSISTENCE_* and anything else.
        return FailureMode.UNKNOWN

    @staticmethod
    def _collect_failed_frames(
        frame_results: list[dict], check_names: set[str]
    ) -> dict[str, list]:
        """Map each check name to the frame indices (in frame order) failing it.

        A frame is listed at most once per check, even if it carries several
        failing results under the same name. Frames without a
        ``frame_index`` are listed as "?".
        """
        failed: dict[str, list] = {}
        for frame in frame_results:
            frame_index = frame.get("frame_index", "?")
            failed_in_frame: set[str] = set()
            for check_result in frame.get("results", []):
                name = check_result["name"]
                if (
                    name in check_names
                    and name not in failed_in_frame
                    and not check_result["passed"]
                ):
                    failed_in_frame.add(name)
                    failed.setdefault(name, []).append(frame_index)
        return failed

    def evaluate(self, artifact: Any, context: dict) -> list[Dimension]:
        """Evaluate a video against all dimensions across sampled frames.

        Args:
            artifact: Path to the video file (str or Path).
            context: Optional context dict (unused).

        Returns:
            List of Dimension results (aggregated across frames), one per
            check, or a single failed FRAME_EXTRACTION dimension when no
            frame results were returned.

        Raises:
            RuntimeError: If ``validate_video_frames`` reports an error.
        """
        video_path = str(artifact)
        checks = self._build_checks()

        context_desc = (
            f"Generated video frame from a {self.character_type} scene."
        )
        if self.expected_style:
            context_desc += f" Expected style: {self.expected_style}."

        result = validate_video_frames(
            video_path, checks, context_desc, num_frames=self.num_frames,
            intention_context=self.intention_context,
        )

        if result.get("error"):
            # Vision API failure — let CriticLoop catch and route to ERROR
            raise RuntimeError(f"VideoFrameCritic vision check failed: {result['error']}")

        frame_results = result.get("frame_results", [])
        if not frame_results:
            # Corrupted video file — must HARD fail, not silent pass
            return [Dimension(
                name="FRAME_EXTRACTION",
                severity=Severity.HARD,
                passed=False,
                message="Could not extract frames from video — file may be corrupted",
                failure_mode=FailureMode.MOTION_FAILURE,
            )]

        # Index failures in one pass instead of rescanning every frame per check.
        failed_frames_by_name = self._collect_failed_frames(
            frame_results, {check["name"] for check in checks}
        )

        # NOTE(review): a check for which no frame returned any verdict is
        # reported as passed — confirm validate_video_frames always answers
        # every check.
        dims = []
        emitted: set[str] = set()
        for check in checks:
            name = check["name"]
            if name in emitted:
                continue  # defensive: an overriding _build_checks may repeat names
            emitted.add(name)

            severity = (
                Severity.HARD if check["severity"] == "hard"
                else Severity.SOFT
            )
            failed_frames = failed_frames_by_name.get(name)
            if failed_frames:
                dims.append(Dimension(
                    name=name,
                    severity=severity,
                    passed=False,
                    message=f"Failed in frames: {failed_frames}",
                    failure_mode=self._failure_mode_for(name),
                ))
            else:
                dims.append(Dimension(
                    name=name,
                    severity=severity,
                    passed=True,
                    message="",
                    failure_mode=None,
                ))

        return dims


# ──────────────────────────────────────────────────────────────────────────
# Structured-output adapter (Phase 2.5 — wired to real CriticLoop)
# ──────────────────────────────────────────────────────────────────────────

def video_frame_critic(
    image=None,
    reference=None,
    structured_output: bool = False,
    video_path: str | None = None,
    character_type: str = "human",
    expected_style: str = "",
    expected_elements: list | None = None,
    **kwargs,
):
    """Structured-output adapter for the video frame critic.

    Runs the real VideoFrameCritic class and translates the result.

    Required for structured_output=True: video_path (or image/reference).
    """
    if not structured_output:
        raise NotImplementedError(
            "Use VideoFrameCritic class directly for non-structured calls."
        )

    vid = video_path or image or reference
    if vid is None:
        raise ValueError(
            "video_frame_critic(structured_output=True) requires "
            "video_path (or image/reference)."
        )

    critic = VideoFrameCritic(
        character_type=character_type,
        expected_style=expected_style,
        expected_elements=expected_elements or [],
        shot_id=kwargs.get("shot_id", ""),
    )
    _, result = critic.run(str(vid))

    # Translate Dimension list → FailureMode (first hard failure wins)
    failure_mode = FailureMode.NONE.value
    for dim in result.hard_failures:
        if dim.name == "EXTRA_LIMBS":
            failure_mode = FailureMode.ANATOMY_LIMB_MISCOUNT.value
            break
        else:
            failure_mode = FailureMode.UNKNOWN.value
            break

    return {
        "passed": result.passed,
        "score": 1.0 if result.passed else 0.0,
        "failure_mode": failure_mode,
        "evidence": "; ".join(d.message for d in result.failed_dimensions if d.message),
        "dimensions": [
            {"name": d.name, "passed": d.passed, "severity": d.severity.value}
            for d in result.dimensions
        ],
    }
