"""
turnaround_critic.py — Turnaround Grid & Panel Validator.

CriticLoop subclass that validates turnaround grids and individual panels
via Gemini Flash vision checks. Used by generate_turnarounds.py when
--validate flash is set.

Two-stage validation:
  Stage 1 (grid): Check panel count, angle diversity, overall layout
  Stage 2 (panels): Per-panel checks for body completeness, eyes, wardrobe, hair, background
                    (only runs when Stage 1 reported no hard failures)

max_attempts=1: turnaround grids need full regeneration, not patching.
"""

import logging
from pathlib import Path
from typing import Any, Optional

from recoil.core.critic import CriticLoop, Dimension, Severity
from recoil.core.vision_check import validate_image

from . import FailureMode

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class TurnaroundCritic(CriticLoop):
    """Validates turnaround grids and panels for character reference quality.

    Dimensions (grid-level):
        FOUR_DISTINCT_PANELS (HARD) — exactly 4 panels, each a different angle
        PANEL_ANGLES_CORRECT (HARD) — front, 3/4, profile, back in order
        POSE_REPETITION (HARD) — no two adjacent panels show the same pose/angle

    Dimensions (per-panel, named ``<CHECK>_<angle>``, e.g. ``HAIR_MATCH_back``):
        BODY_COMPLETENESS (HARD) — full body or medium shot as requested
        EYE_CONSISTENCY (HARD) — both eyes same color (frontal + three-quarter only)
        WARDROBE_MATCH (HARD) — matches specified wardrobe (only if wardrobe given)
        HAIR_MATCH (HARD) — matches specified hair (only if hair given)
        BACKGROUND_CLEAN (SOFT) — plain neutral gray

    Per-panel checks only run when panel paths are supplied and no grid-level
    HARD check failed.

    Angle-aware validation: checks that are physically impossible for a given
    viewing angle (e.g. EYE_CONSISTENCY for back/profile) are not asked at all,
    rather than being failed.

    Empty validator responses on per-panel checks (``""``, whitespace or
    ``None``) are handled by severity. A HARD check fails closed with an
    "UNDECIDED" message so it can be triaged. A SOFT check is recorded as an
    N/A pass.
    """

    # Checks that require eyes to be visible — only valid for frontal and
    # three-quarter panels.  Profile and back views physically cannot show
    # both eyes, so Gemini returns empty strings → spurious hard failures.
    _EYES_REQUIRED_ANGLES = frozenset({"frontal", "three-quarter"})

    # Panel order the grid prompts ask for (left to right).  Panels beyond
    # these four are labelled "panel_<index>".
    _ANGLE_NAMES = ("frontal", "three-quarter", "profile", "back")

    def __init__(
        self,
        wardrobe: str = "",
        hair: str = "",
        framing: str = "full_body",
        experience_pool_dir: Optional[Path] = None,
        shot_id: str = "",
    ):
        """Configure the critic.

        Args:
            wardrobe: Wardrobe description; when empty, WARDROBE_MATCH is not checked.
            hair: Hair description; when empty, HAIR_MATCH is not checked.
            framing: "full_body" asks for head-to-toe framing; any other value
                asks for a waist-up medium shot.
            experience_pool_dir: Forwarded to CriticLoop.
            shot_id: Forwarded to CriticLoop.
        """
        super().__init__(
            name="turnaround",
            max_attempts=1,
            experience_pool_dir=experience_pool_dir,
            shot_id=shot_id,
        )
        self.wardrobe = wardrobe
        self.hair = hair
        self.framing = framing

    def _build_grid_checks(self) -> list[dict]:
        """Build check list for the full grid image."""
        return [
            {
                "name": "FOUR_DISTINCT_PANELS",
                "question": (
                    "Does this image show exactly 4 separate panels of the same character, "
                    "each from a clearly different angle? Answer 'yes' or 'no'."
                ),
                "expected": "yes",
                "severity": "hard",
            },
            {
                "name": "PANEL_ANGLES_CORRECT",
                "question": (
                    "Reading left to right: Is panel 1 a frontal view, panel 2 a three-quarter "
                    "angle (about 45 degrees), panel 3 a profile/side view (about 90 degrees), "
                    "and panel 4 a back view? Answer 'yes' if all 4 match, or 'no' and describe "
                    "which panel has the wrong angle."
                ),
                "expected": "yes",
                "severity": "hard",
            },
            {
                "name": "POSE_REPETITION",
                "question": (
                    "Look at all 4 panels. Are there any two adjacent panels that show "
                    "essentially the same angle/pose? Panel 1 should be frontal, Panel 2 "
                    "should be three-quarter (both eyes visible, slight turn), Panel 3 should "
                    "be full profile (side view, one eye), Panel 4 should be back view. "
                    "Answer 'yes' if all four panels show clearly DIFFERENT angles, 'no' if "
                    "any two panels are too similar."
                ),
                "expected": "yes",
                "severity": "hard",
            },
        ]

    def _build_panel_checks(self, angle: str = "frontal") -> list[dict]:
        """Build check list for individual panels.

        Args:
            angle: Panel angle name (frontal, three-quarter, profile, back).
                   Used to skip checks that are physically impossible for the
                   given view (e.g. EYE_CONSISTENCY for back panels).

        Returns:
            Check dicts in the order BODY_COMPLETENESS, EYE_CONSISTENCY (eyes-visible
            angles only), WARDROBE_MATCH (if wardrobe set), HAIR_MATCH (if hair set),
            BACKGROUND_CLEAN.
        """
        checks = []

        if self.framing == "full_body":
            body_question = (
                "Is this a full-body image showing the character from head to toe? "
                "Is the head visible at the top and feet visible at the bottom? "
                "Answer 'yes' if complete, 'no' if any part is cut off."
            )
        else:
            body_question = (
                "Is this a medium shot showing the character from the waist up? "
                "Is the face clearly visible with good detail? "
                "Answer 'yes' if framing is waist-up or closer."
            )
        checks.append({
            "name": "BODY_COMPLETENESS",
            "question": body_question,
            "expected": "yes",
            "severity": "hard",
        })

        # EYE_CONSISTENCY only applies when both eyes are visible
        if angle in self._EYES_REQUIRED_ANGLES:
            checks.append({
                "name": "EYE_CONSISTENCY",
                "question": (
                    "Look at the character's eyes carefully. Are both eyes the same color? "
                    "Answer 'yes' if same color, 'no' if different colors (heterochromia)."
                ),
                "expected": "yes",
                "severity": "hard",
            })

        if self.wardrobe:
            checks.append({
                "name": "WARDROBE_MATCH",
                "question": f"Is the character wearing: {self.wardrobe}? Answer 'yes' or 'no'.",
                "expected": "yes",
                "severity": "hard",
            })

        if self.hair:
            checks.append({
                "name": "HAIR_MATCH",
                "question": (
                    f"Does the character's hair match this description: {self.hair}? "
                    "Answer 'yes' or 'no'."
                ),
                "expected": "yes",
                "severity": "hard",
            })

        checks.append({
            "name": "BACKGROUND_CLEAN",
            "question": (
                "Is the background a plain neutral gray color with no scene elements, "
                "objects, or environment? Answer 'yes' for clean, 'no' for contaminated."
            ),
            "expected": "yes",
            "severity": "soft",
        })

        return checks

    @staticmethod
    def _coerce_severity(raw: Any) -> Severity:
        """Normalize a validator severity to a Severity member.

        Accepts a Severity member (returned unchanged), or a string such as
        "hard", "HARD" or "Severity.HARD" (case-insensitive, any dotted prefix
        ignored). Anything that does not spell HARD maps to SOFT. This matches
        the previous behavior for the lowercase "hard"/"soft" values this class
        emits, without silently downgrading other spellings of HARD.
        """
        if isinstance(raw, Severity):
            return raw
        token = str(raw).rsplit(".", 1)[-1].strip().lower()
        return Severity.HARD if token == "hard" else Severity.SOFT

    @staticmethod
    def _is_empty_answer(answer: Any) -> bool:
        """Return True when the validator produced no usable answer (None, "" or whitespace)."""
        return answer is None or not str(answer).strip()

    @staticmethod
    def _panel_failure_mode(check_name: str) -> FailureMode:
        """Map a failed per-panel check name (without angle suffix) to a FailureMode."""
        if check_name == "WARDROBE_MATCH":
            return FailureMode.WARDROBE_MISMATCH
        if check_name in ("EYE_CONSISTENCY", "HAIR_MATCH"):
            return FailureMode.IDENTITY_DRIFT
        if check_name == "BACKGROUND_CLEAN":
            return FailureMode.BACKGROUND_CONTAMINATION
        return FailureMode.UNKNOWN

    def _evaluate_grid(self, grid_path: str) -> tuple[list[Dimension], bool]:
        """Run the grid-level checks on *grid_path*.

        Returns:
            (dimensions, grid_failed), where grid_failed is True if any HARD
            grid check failed.

        Raises:
            RuntimeError: if validate_image reports an error.
        """
        grid_result = validate_image(
            grid_path, self._build_grid_checks(),
            context_description="Character turnaround grid — 4 angles side by side",
        )
        if grid_result.get("error"):
            # Vision API failure — let CriticLoop catch and route to ERROR
            raise RuntimeError(f"TurnaroundCritic grid vision check failed: {grid_result['error']}")

        dims: list[Dimension] = []
        grid_failed = False
        for check_result in grid_result.get("results", []):
            severity = self._coerce_severity(check_result["severity"])
            passed = check_result["passed"]
            message = ""
            if not passed:
                message = (f"Expected '{check_result['expected']}', "
                           f"got '{check_result.get('answer')}'")
                if severity == Severity.HARD:
                    grid_failed = True
            dims.append(Dimension(name=check_result["name"], severity=severity,
                                  passed=passed, message=message,
                                  failure_mode=None if passed else FailureMode.UNKNOWN))
        return dims, grid_failed

    def _panel_dimension(self, check_result: dict, angle: str) -> Dimension:
        """Convert one per-panel validator result into a Dimension named ``<CHECK>_<angle>``."""
        check_name = check_result["name"]
        dim_name = f"{check_name}_{angle}"
        severity = self._coerce_severity(check_result["severity"])
        answer = check_result.get("answer")
        passed = check_result["passed"]

        if passed:
            return Dimension(name=dim_name, severity=severity, passed=passed,
                             message="", failure_mode=None)

        if self._is_empty_answer(answer):
            # Gemini sometimes returns empty strings for ambiguous panels
            # (e.g. wardrobe/hair hard to judge from back). For HARD checks,
            # treat empty as a real failure with an UNDECIDED message so
            # morning triage can review. For SOFT checks, log and pass.
            # (Audit finding turnaround:1882, 2026-04-09 — was silent bypass.)
            if severity == Severity.HARD:
                logger.warning(
                    "TurnaroundCritic: %s_%s got empty response on HARD check — UNDECIDED, failing closed",
                    check_name, angle,
                )
                return Dimension(
                    name=dim_name,
                    severity=severity,
                    passed=False,
                    message=f"[{angle}] UNDECIDED — empty validator response on HARD check",
                    failure_mode=FailureMode.UNKNOWN,
                )
            logger.info(
                "TurnaroundCritic: %s_%s got empty response on SOFT check — N/A pass",
                check_name, angle,
            )
            return Dimension(
                name=dim_name,
                severity=severity,
                passed=True,
                message=f"[{angle}] N/A — empty response (SOFT)",
                failure_mode=None,
            )

        return Dimension(
            name=dim_name,
            severity=severity,
            passed=passed,
            message=f"[{angle}] Expected '{check_result['expected']}', got '{answer}'",
            failure_mode=self._panel_failure_mode(check_name),
        )

    def _evaluate_panel(self, panel_path: Any, angle: str) -> list[Dimension]:
        """Run the per-panel checks appropriate for *angle* on one panel image.

        Raises:
            RuntimeError: if validate_image reports an error.
        """
        panel_result = validate_image(
            str(panel_path), self._build_panel_checks(angle=angle),
            context_description=f"Character turnaround panel — {angle} view",
        )
        if panel_result.get("error"):
            # Vision API failure — let CriticLoop catch and route to ERROR
            raise RuntimeError(
                f"TurnaroundCritic panel {angle} vision check failed: {panel_result['error']}"
            )
        return [self._panel_dimension(check_result, angle)
                for check_result in panel_result.get("results", [])]

    def evaluate(self, artifact: Any, context: dict) -> list[Dimension]:
        """Evaluate a turnaround grid image.

        Args:
            artifact: Path to the grid image (str or Path).
            context: Dict with optional 'panel_paths' key for per-panel checks.
                Panels are labelled by position: frontal, three-quarter,
                profile, back, then panel_<index>.

        Returns:
            Grid-level Dimensions, followed by per-panel Dimensions when panel
            paths were given and no HARD grid check failed.

        Raises:
            RuntimeError: on a vision API failure (for CriticLoop to route to ERROR).
        """
        dims, grid_failed = self._evaluate_grid(str(artifact))

        # Stage 2: Per-panel checks (only if grid passed)
        panel_paths = (context or {}).get("panel_paths") or []
        if grid_failed or not panel_paths:
            return dims

        for index, panel_path in enumerate(panel_paths):
            if index < len(self._ANGLE_NAMES):
                angle = self._ANGLE_NAMES[index]
            else:
                angle = f"panel_{index}"
            dims.extend(self._evaluate_panel(panel_path, angle))

        return dims


def validate_turnaround_grid(grid_img, panels, wardrobe: str, hair: str,
                              framing: str) -> tuple[bool, str]:
    """Convenience function for generate_turnarounds.py.

    Writes the grid and each panel to PNG files in a private temporary
    directory, runs TurnaroundCritic on them, and returns (passed, reason).

    Args:
        grid_img: Full turnaround grid; must support ``save(fileobj, "PNG")``
            (presumably a PIL image — confirm against caller).
        panels: Iterable of panel images with the same ``save`` API, in the
            order TurnaroundCritic expects (frontal, three-quarter, profile, back).
        wardrobe: Forwarded to TurnaroundCritic.
        hair: Forwarded to TurnaroundCritic.
        framing: Forwarded to TurnaroundCritic.

    Returns:
        ``(False, "NAME: message; ...")`` listing every hard failure, otherwise
        ``(True, "OK")``.
    """
    import tempfile

    # The temporary directory (and every file written into it) is removed on
    # exit, including when a save() or the critic raises part-way through.
    # Previously, temp files written before such a failure were leaked.
    with tempfile.TemporaryDirectory(prefix="turnaround_") as tmp_dir:

        def _write_png(image: Any, filename: str) -> str:
            """Save *image* as PNG inside tmp_dir and return its path as a string."""
            path = Path(tmp_dir) / filename
            with open(path, "wb") as fh:
                image.save(fh, "PNG")
            return str(path)

        grid_tmp = _write_png(grid_img, "grid.png")
        panel_tmps = [_write_png(panel, f"panel_{index}.png")
                      for index, panel in enumerate(panels)]

        critic = TurnaroundCritic(wardrobe=wardrobe, hair=hair, framing=framing)
        _, result = critic.run(grid_tmp, context={"panel_paths": panel_tmps})

        # NOTE(review): only hard failures are reported. If CriticLoop routes a
        # vision-API RuntimeError to an ERROR result without hard failures, this
        # returns (True, "OK") — confirm that is intended.
        if result.hard_failures:
            reasons = [f"{d.name}: {d.message}" for d in result.hard_failures]
            return False, "; ".join(reasons)

        return True, "OK"


# ──────────────────────────────────────────────────────────────────────────
# Structured-output adapter (Phase 2.5 — wired to real CriticLoop)
# ──────────────────────────────────────────────────────────────────────────

def turnaround_critic(
    image=None,
    reference=None,
    structured_output: bool = False,
    grid_path: str | None = None,
    panel_paths: list | None = None,
    wardrobe: str = "",
    hair: str = "",
    framing: str = "full_body",
    **kwargs,
):
    """Structured-output adapter for the turnaround critic.

    Runs the real TurnaroundCritic class and translates the result.

    Required for structured_output=True: grid_path (or image/reference).
    """
    if not structured_output:
        raise NotImplementedError(
            "Use TurnaroundCritic class directly for non-structured calls."
        )

    grid = grid_path or image or reference
    if grid is None:
        raise ValueError(
            "turnaround_critic(structured_output=True) requires "
            "grid_path (or image/reference)."
        )

    critic = TurnaroundCritic(
        wardrobe=wardrobe,
        hair=hair,
        framing=framing,
        shot_id=kwargs.get("shot_id", ""),
    )
    context = {"panel_paths": panel_paths or []}
    _, result = critic.run(str(grid), context=context)

    # Translate Dimension list → FailureMode (first hard failure wins)
    failure_mode = FailureMode.NONE.value
    for dim in result.hard_failures:
        dim_name = dim.name.upper()
        # Strip angle suffix for per-panel checks (e.g. WARDROBE_MATCH_frontal)
        base_name = dim_name.rsplit("_", 1)[0] if "_" in dim_name else dim_name
        if any(bn in dim_name for bn in ["WARDROBE_MATCH", "WARDROBE"]):
            failure_mode = FailureMode.WARDROBE_MISMATCH.value
            break
        elif any(bn in dim_name for bn in ["EYE_CONSISTENCY", "HAIR_MATCH", "HAIR"]):
            failure_mode = FailureMode.IDENTITY_DRIFT.value
            break
        elif "BACKGROUND_CLEAN" in dim_name:
            failure_mode = FailureMode.BACKGROUND_CONTAMINATION.value
            break
        elif dim_name in ("FOUR_DISTINCT_PANELS", "PANEL_ANGLES_CORRECT", "POSE_REPETITION"):
            failure_mode = FailureMode.UNKNOWN.value
            break

    return {
        "passed": result.passed,
        "score": 1.0 if result.passed else 0.0,
        "failure_mode": failure_mode,
        "evidence": "; ".join(d.message for d in result.failed_dimensions if d.message),
        "dimensions": [
            {"name": d.name, "passed": d.passed, "severity": d.severity.value}
            for d in result.dimensions
        ],
    }
