"""
character_consistency_critic.py — Cross-Shot Character Consistency Validator.

Groups all appearances of the same character across shots, creates a
comparison grid image, and asks Gemini Flash to identify outliers.
Flagged shots get pair-level comparison at higher resolution.

Two-pass approach:
  Pass 1 (Grid): All shots for one character in a single grid image.
         Flash identifies which shots look inconsistent. Cost: ~$0.01/character.
  Pass 2 (Pairs): Each flagged shot is compared 1:1 against a reference shot
         (the first shot not flagged in Pass 1, or the first shot if every
         shot was flagged). Flash identifies specific inconsistencies.
         Cost: ~$0.01/pair.
"""

import logging
from pathlib import Path
from typing import Optional

from PIL import Image as PILImage

from recoil.core.critic import FailureMode
from recoil.core.vision_check import validate_image

logger = logging.getLogger(__name__)


def _create_grid(image_paths: list[Path], max_cols: int = 4, cell_size: int = 512) -> Path:
    """Create a labeled grid image from multiple shots.

    Each cell is labeled with its shot index (1-based) in the top-left corner.
    Returns path to temporary grid image.
    """
    import math
    import tempfile

    n = len(image_paths)
    cols = min(n, max_cols)
    rows = math.ceil(n / cols)

    grid = PILImage.new("RGB", (cols * cell_size, rows * cell_size), (40, 40, 40))

    for i, path in enumerate(image_paths):
        try:
            img = PILImage.open(path).convert("RGB")
            img = img.resize((cell_size, cell_size), PILImage.LANCZOS)
            col = i % cols
            row = i // cols
            grid.paste(img, (col * cell_size, row * cell_size))
        except Exception as e:
            logger.warning("Could not load image %s for grid: %s", path, e)

    tmp = tempfile.NamedTemporaryFile(suffix="_consistency_grid.jpg", delete=False)
    grid_path = Path(tmp.name)
    tmp.close()
    grid.save(grid_path, quality=90)
    return grid_path


# Pass-2 questions asked for every (reference, outlier) pair. Stored as a
# template; each comparison receives fresh dict copies so nothing downstream
# can mutate the shared definitions.
_PAIR_CHECKS = (
    {
        "name": "HAIR",
        "question": "Compare the character's hair between these two images. Is the hair color, length, and style the same?",
        "expected": "yes",
        "severity": "HARD",
    },
    {
        "name": "CLOTHING",
        "question": "Compare the character's clothing between these two images. Is the clothing color, style, and type the same?",
        "expected": "yes",
        "severity": "HARD",
    },
    {
        "name": "SKIN_TONE",
        "question": "Compare the character's skin tone between these two images. Is it noticeably different?",
        "expected": "no",
        "severity": "SOFT",
    },
    {
        "name": "ACCESSORIES",
        "question": "Compare accessories (glasses, jewelry, hats) between these two images. Are the same accessories present in both?",
        "expected": "yes",
        "severity": "SOFT",
    },
)


def _consistency_result(
    character_name: str,
    total_shots: int,
    outlier_shots: list[str],
    details: list[dict],
) -> dict:
    """Build the result dict returned by ``check_character_consistency``."""
    return {
        "character": character_name,
        "total_shots": total_shots,
        "consistent": not outlier_shots,
        "outlier_shots": outlier_shots,
        "details": details,
        "failure_mode": (
            FailureMode.IDENTITY_DRIFT.value if outlier_shots else FailureMode.NONE.value
        ),
    }


def _parse_outlier_indices(answer: str, n_shots: int) -> list[int]:
    """Extract 0-based outlier indices from the Pass-1 grid answer.

    Returns ``[]`` when the answer says "all consistent", unless that phrase
    is negated (e.g. "not all consistent", "aren't all consistent"). Otherwise
    every integer in ``1..n_shots`` is taken as a 1-based shot label.
    Duplicates are dropped in first-mention order, so a shot the model names
    twice is pair-checked only once.
    """
    import re

    lowered = (answer or "").lower()
    negated = re.search(r"(?:\bnot|n't)\s+all\s+consistent", lowered) is not None
    if "all consistent" in lowered and not negated:
        return []
    labels = (int(token) for token in re.findall(r"\b(\d+)\b", lowered))
    return list(dict.fromkeys(label - 1 for label in labels if 0 < label <= n_shots))


def _save_pair_image(reference_path: Path, outlier_path: Path, cell_size: int = 512) -> Path:
    """Write a side-by-side JPEG (reference left, outlier right) to a temp file.

    The caller owns the returned file and must delete it. If writing fails,
    the temp file is removed before the error propagates.
    """
    import tempfile

    with PILImage.open(reference_path) as ref_src, PILImage.open(outlier_path) as out_src:
        ref_img = ref_src.convert("RGB").resize((cell_size, cell_size))
        out_img = out_src.convert("RGB").resize((cell_size, cell_size))
    pair_img = PILImage.new("RGB", (cell_size * 2, cell_size))
    pair_img.paste(ref_img, (0, 0))
    pair_img.paste(out_img, (cell_size, 0))

    tmp = tempfile.NamedTemporaryFile(suffix="_consistency_pair.jpg", delete=False)
    pair_path = Path(tmp.name)
    tmp.close()
    try:
        pair_img.save(pair_path, quality=90)
    except Exception:
        pair_path.unlink(missing_ok=True)
        raise
    return pair_path


def check_character_consistency(
    character_name: str,
    shot_frames: dict[str, Path],
    character_anchor: str = "",
    experience_pool_dir: Optional[Path] = None,
) -> dict:
    """Run cross-shot consistency check for one character.

    Pass 1 sends one grid image of all frames to ``validate_image`` and parses
    the flagged shot numbers. Pass 2 compares each flagged shot side by side
    against a reference shot: the first shot not flagged, or shot index 0 if
    every shot was flagged. In that case shot 0 itself is not compared, since
    comparing a frame with itself proves nothing.

    Args:
        character_name: Name of the character being checked.
        shot_frames: Dict of {shot_id: frame_path} for all shots featuring this character.
        character_anchor: Frozen character description (identity anchor).
        experience_pool_dir: Accepted for API compatibility; currently unused
            (this function writes no JSONL log).

    Returns:
        {
            "character": str,
            "total_shots": int,
            "consistent": bool,
            "outlier_shots": list[str],  # shot IDs confirmed inconsistent by Pass 2
            "details": list[dict],       # per-pair results for flagged shots
            "failure_mode": str,         # FailureMode.IDENTITY_DRIFT/NONE value
        }

    Raises:
        RuntimeError: If any Pass-2 pair comparison fails (fail-closed, so the
            calling CriticLoop can route to ERROR). Errors raised by Pass 1
            (grid creation or the grid ``validate_image`` call) propagate unchanged.
    """
    shot_ids = list(shot_frames.keys())
    frame_paths = [shot_frames[sid] for sid in shot_ids]
    total = len(frame_paths)

    if total < 2:
        return _consistency_result(character_name, total, [], [])

    intention_context = {"character_anchor": character_anchor} if character_anchor else None

    # -- Pass 1: Grid comparison -------------------------------------------
    grid_checks = [
        {
            "name": "CONSISTENCY_OVERVIEW",
            "question": (
                f"This grid shows {len(frame_paths)} appearances of the character "
                f"'{character_name}' across different shots (labeled 1-{len(frame_paths)}). "
                f"Character description: {character_anchor or 'not provided'}. "
                "Which shot numbers (if any) look visually inconsistent with the others? "
                "Check: hair color/style, skin tone, clothing, facial features, "
                "accessories. List the inconsistent shot numbers, or say 'all consistent'."
            ),
            "expected": "all consistent",
            "severity": "SOFT",
        }
    ]

    grid_path = _create_grid(frame_paths)
    try:
        grid_result = validate_image(
            image_path=str(grid_path),
            checks=grid_checks,
            context_description=f"Character consistency grid for '{character_name}'",
            intention_context=intention_context,
        )
    finally:
        # Clean up grid temp file even if validate_image throws
        try:
            grid_path.unlink()
        except OSError:
            pass

    grid_results = grid_result.get("results") or []
    answer = grid_results[0].get("answer", "") if grid_results else ""
    outlier_indices = _parse_outlier_indices(answer, total)

    if not outlier_indices:
        return _consistency_result(character_name, total, [], [])

    # -- Pass 2: Pair comparison for flagged shots -------------------------
    flagged = set(outlier_indices)
    reference_idx = next((i for i in range(total) if i not in flagged), 0)
    reference_id = shot_ids[reference_idx]
    reference_path = frame_paths[reference_idx]

    pair_details = []
    for idx in outlier_indices:
        if idx == reference_idx:
            # Only reachable when every shot was flagged and we fell back to shot 0.
            continue
        outlier_id = shot_ids[idx]

        try:
            pair_path = _save_pair_image(reference_path, frame_paths[idx])
            try:
                pair_result = validate_image(
                    image_path=str(pair_path),
                    checks=[dict(check) for check in _PAIR_CHECKS],
                    context_description=(
                        f"Character consistency pair: '{character_name}'. "
                        f"Left image is the reference (shot {reference_id}). "
                        f"Right image is the outlier (shot {outlier_id})."
                    ),
                    intention_context=intention_context,
                )
            finally:
                try:
                    pair_path.unlink()
                except OSError:
                    pass

            failed_checks = [
                r for r in pair_result.get("results", [])
                if not r.get("passed")
            ]
            pair_details.append({
                "shot_id": outlier_id,
                "reference_shot": reference_id,
                "passed": not failed_checks,
                "failed_checks": [
                    {"name": r["name"], "answer": r["answer"], "severity": r["severity"]}
                    for r in failed_checks
                ],
            })
        except Exception as e:
            logger.error("Pair comparison failed for shot %s: %s", outlier_id, e)
            # Fail-closed: re-raise so the calling CriticLoop routes to ERROR
            raise RuntimeError(
                f"Pair comparison crashed for shot {outlier_id}: {e}"
            ) from e

    confirmed_outliers = [d["shot_id"] for d in pair_details if not d["passed"]]
    return _consistency_result(character_name, total, confirmed_outliers, pair_details)


# ──────────────────────────────────────────────────────────────────────────
# Structured-output adapter (Phase 25, T1.16)
# ──────────────────────────────────────────────────────────────────────────
# Full Gemini response_schema integration is deferred to Phase 2. For now we
# expose a thin adapter that accepts ``structured_output=True`` and returns a
# dict with a FailureMode-typed ``failure_mode`` field, so the feedback agent
# has a stable contract to code against.

def character_consistency_critic(
    image=None,
    reference=None,
    structured_output: bool = False,
    **kwargs,
):
    """Adapter exposing cross-shot character consistency as a critic.

    ``image`` and ``reference`` are accepted for signature compatibility with
    other critics but are not used. The check itself is driven by the keyword
    arguments ``character_name``, ``shot_frames``, ``character_anchor`` and
    ``experience_pool_dir``, which are forwarded to ``check_character_consistency``.

    With ``structured_output=False`` the raw ``check_character_consistency``
    result is returned unchanged. With ``structured_output=True`` a compact
    dict ``{passed, score, failure_mode, evidence}`` is returned instead.

    Raises:
        ValueError: In structured mode, when fewer than two shot frames are given.
        RuntimeError: In structured mode, wrapping any failure of the check.
    """
    forwarded = {
        "character_name": kwargs.get("character_name", ""),
        "shot_frames": kwargs.get("shot_frames", {}),
        "character_anchor": kwargs.get("character_anchor", ""),
        "experience_pool_dir": kwargs.get("experience_pool_dir"),
    }

    if not structured_output:
        return check_character_consistency(**forwarded)

    # A missing/None/empty mapping counts as zero frames.
    if len(forwarded["shot_frames"] or {}) < 2:
        raise ValueError(
            "character_consistency_critic(structured_output=True) requires "
            "shot_frames dict with at least 2 entries."
        )

    try:
        outcome = check_character_consistency(**forwarded)
        is_consistent = bool(outcome.get("consistent", True))
        mode = FailureMode.NONE if is_consistent else FailureMode.IDENTITY_DRIFT
        return {
            "passed": is_consistent,
            "score": 1.0 if is_consistent else 0.0,
            "failure_mode": mode.value,
            "evidence": f"outliers={outcome.get('outlier_shots', [])}",
        }
    except Exception as exc:
        # Vision/API failures are surfaced as RuntimeError so the caller's
        # CriticLoop wrapper routes them to ERROR.
        raise RuntimeError(f"character_consistency_critic vision check failed: {exc}") from exc
