"""
blocking_validators.py — Validation rules for Stage 2.5 Blocking Pass.

Hard validators block plan write-back on failure.
Soft validators log warnings but allow the plan to be saved.

Usage:
    from recoil.pipeline._lib.blocking_validators import run_all_validators
    violations = run_all_validators(plan_dict, bible)
"""

from __future__ import annotations

import re
from typing import Optional

from recoil.pipeline._lib.render_schema import GlobalBible


# ---------------------------------------------------------------------------
# Appearance rejection patterns
# ---------------------------------------------------------------------------

# Case-insensitive regexes that flag physical-appearance descriptors in a
# shot's subject_line: age/gender phrases, hair colour, eye colour, build,
# clothing worn, and marked skin. Consumed by validate_appearance_rejection.
_APPEARANCE_PATTERNS = [
    re.compile(_src, re.I)
    for _src in (
        r"\b(young|old|elderly|middle-aged)\s+(woman|man|person|figure)\b",
        r"\b(dark|light|blonde|red|brown|black|grey|gray|white)\s+(hair|haired)\b",
        r"\b(blue|green|brown|hazel|dark)\s+eyes\b",
        r"\b(tall|short|stocky|thin|muscular|lean)\s+(figure|build|frame)\b",
        r"\bwearing\s+(a\s+)?(dress|shirt|jacket|pants|skirt|coat|uniform|boots|gloves|outfit|suit|vest|armor|helmet)\b",
        r"\b(grime-streaked|scarred|tattooed)\s+(face|skin|complexion)\b",
    )
]


# ---------------------------------------------------------------------------
# Hard Validators — block write-back on failure
# ---------------------------------------------------------------------------


def validate_schema_compliance(plan_dict: dict) -> list[str]:
    """Check the structure of blocking_metadata on every non-ENV shot.

    For each shot that is not ENV-only (``routing_data.is_env_only``) and that
    carries ``blocking_metadata``, require:

    * a non-empty ``characters`` list, and for every character
    * a truthy ``character_id``,
    * a truthy ``dominant_hand.hand`` and ``secondary_hand.hand``.

    ENV-only shots and shots without ``blocking_metadata`` are skipped.
    Containers that are present but ``None`` (e.g. ``"secondary_hand": null``
    or ``"routing_data": null``) are treated like missing ones, so they are
    reported as violations instead of raising.

    Args:
        plan_dict: Plan dictionary with a ``"shots"`` list of shot dicts.

    Returns:
        ``[HARD]`` error strings; empty when every checked shot is compliant.
    """
    errors: list[str] = []
    for shot in plan_dict.get("shots") or []:
        shot_id = shot.get("shot_id", "?")
        is_env = (shot.get("routing_data") or {}).get("is_env_only", False)
        if is_env:
            continue  # ENV shots may or may not have blocking

        bm = shot.get("blocking_metadata")
        if bm is None:
            # This shot has no blocking data; nothing to validate here.
            continue

        chars = bm.get("characters") or []
        if not chars:
            errors.append(f"[HARD] {shot_id}: blocking_metadata.characters is empty for non-ENV shot")

        for c in chars:
            if not c.get("character_id"):
                errors.append(f"[HARD] {shot_id}: character_id missing in blocking")
            # `or {}` so an explicit null hand is reported, not an AttributeError.
            dh = c.get("dominant_hand") or {}
            sh = c.get("secondary_hand") or {}
            if not dh.get("hand"):
                errors.append(f"[HARD] {shot_id}/{c.get('character_id', '?')}: dominant_hand.hand missing")
            if not sh.get("hand"):
                errors.append(f"[HARD] {shot_id}/{c.get('character_id', '?')}: secondary_hand.hand missing")

    return errors


def _format_id_set(ids: set) -> str:
    """Render *ids* like a set literal, elements in sorted ``repr`` order.

    A one-element set renders exactly as ``str({x})``. For larger sets the
    fixed ordering keeps messages stable across runs, because set iteration
    order for strings varies with hash randomisation.
    """
    return "{" + ", ".join(sorted(repr(i) for i in ids)) + "}"


def validate_character_parity(plan_dict: dict) -> list[str]:
    """Check that blocked characters match the in-frame characters in asset_data.

    For every non-ENV shot with ``blocking_metadata``, the set of
    ``blocking_metadata.characters[*].character_id`` must equal the set of
    in-frame ids in ``asset_data.characters``. Entries there are either bare
    ids or dicts with ``char_id`` and an optional ``visibility``, which
    defaults to ``"in_frame"``. Entries with any other visibility are ignored.

    An in-frame dict entry without a ``char_id`` is reported as its own
    ``[HARD]`` error. Previously such an entry was added to a set as the dict
    itself, which raised ``TypeError``.

    Args:
        plan_dict: Plan dictionary with a ``"shots"`` list of shot dicts.

    Returns:
        ``[HARD]`` error strings, one per offending entry or mismatched shot.
    """
    errors: list[str] = []
    for shot in plan_dict.get("shots") or []:
        shot_id = shot.get("shot_id", "?")
        bm = shot.get("blocking_metadata")
        if bm is None:
            continue

        if (shot.get("routing_data") or {}).get("is_env_only", False):
            continue

        asset_chars = set()
        for c in (shot.get("asset_data") or {}).get("characters") or []:
            if isinstance(c, dict):
                cid = c.get("char_id")
                vis = c.get("visibility", "in_frame")
            else:
                cid = c  # bare character id
                vis = "in_frame"
            # Only count in-frame characters
            if vis != "in_frame":
                continue
            if cid is None:
                errors.append(f"[HARD] {shot_id}: asset_data character entry missing char_id")
                continue
            asset_chars.add(cid)

        blocking_chars = {c.get("character_id") for c in bm.get("characters") or []}

        if asset_chars != blocking_chars:
            missing = asset_chars - blocking_chars
            extra = blocking_chars - asset_chars
            parts = []
            if missing:
                parts.append(f"missing blocking for {_format_id_set(missing)}")
            if extra:
                parts.append(f"extra blocking for {_format_id_set(extra)}")
            errors.append(f"[HARD] {shot_id}: character parity mismatch — {', '.join(parts)}")

    return errors


def validate_prop_id_validity(plan_dict: dict, bible: GlobalBible) -> list[str]:
    """Check prop_ids in hand states against bible props.

    Unknown props are soft warnings — Gemini may infer contextual props
    (rebreather, tether, cable) from the script that aren't registered
    in the bible's prop registry. These are descriptive, not identity-critical.

    Args:
        plan_dict: Plan dictionary with a ``"shots"`` list of shot dicts.
        bible: Object whose ``props`` mapping's keys are the registered prop ids.

    Returns:
        ``[WARN]`` strings, one per hand whose ``prop_id`` is not a key of
        ``bible.props``. A hand entry that is missing or ``None`` is skipped.
    """
    errors: list[str] = []
    valid_props = set(bible.props.keys())

    for shot in plan_dict.get("shots") or []:
        shot_id = shot.get("shot_id", "?")
        bm = shot.get("blocking_metadata")
        if bm is None:
            continue

        for c in bm.get("characters") or []:
            for hand_key in ("dominant_hand", "secondary_hand"):
                # `or {}`: an explicit null hand carries no prop to check.
                pid = (c.get(hand_key) or {}).get("prop_id")
                if pid and pid not in valid_props:
                    errors.append(
                        f"[WARN] {shot_id}/{c.get('character_id', '?')}: "
                        f"{hand_key}.prop_id='{pid}' not in bible props"
                    )

    return errors


def validate_env_shots(plan_dict: dict) -> list[str]:
    """ENV shots must have empty characters in blocking.

    A shot is ENV-only when ``routing_data.is_env_only`` is truthy. If such a
    shot has ``blocking_metadata`` with a non-empty ``characters`` list, one
    ``[HARD]`` error is emitted that gives the number of characters.
    ``routing_data`` may be missing or ``None``; both mean "not ENV-only".

    Args:
        plan_dict: Plan dictionary with a ``"shots"`` list of shot dicts.

    Returns:
        ``[HARD]`` error strings; empty when no ENV shot has blocked characters.
    """
    errors: list[str] = []
    for shot in plan_dict.get("shots") or []:
        shot_id = shot.get("shot_id", "?")
        is_env = (shot.get("routing_data") or {}).get("is_env_only", False)
        bm = shot.get("blocking_metadata")
        chars = bm.get("characters") if bm else None

        if is_env and chars:
            errors.append(
                f"[HARD] {shot_id}: ENV shot has {len(chars)} "
                f"characters in blocking_metadata"
            )

    return errors


def validate_appearance_rejection(plan_dict: dict) -> list[str]:
    """subject_line should not contain appearance descriptors after blocking pass.

    For every shot that has ``blocking_metadata``, the text at
    ``prompt_data.prompt_skeleton.subject_line`` is searched with each
    pattern in ``_APPEARANCE_PATTERNS``, in order. The first match produces
    a single ``[HARD]`` error for that shot. Missing or ``None`` containers
    and a ``None`` subject_line are treated as an empty subject, so
    ``re.search`` is never handed ``None``.

    Args:
        plan_dict: Plan dictionary with a ``"shots"`` list of shot dicts.

    Returns:
        ``[HARD]`` error strings, at most one per shot.
    """
    errors: list[str] = []
    for shot in plan_dict.get("shots") or []:
        shot_id = shot.get("shot_id", "?")
        if shot.get("blocking_metadata") is None:
            continue  # Only check shots that went through blocking pass

        prompt_data = shot.get("prompt_data") or {}
        skeleton = prompt_data.get("prompt_skeleton") or {}
        subject = skeleton.get("subject_line") or ""
        for pattern in _APPEARANCE_PATTERNS:
            match = pattern.search(subject)
            if match:
                errors.append(
                    f"[HARD] {shot_id}: subject_line contains appearance term: "
                    f"'{match.group()}'"
                )
                break  # One hit per shot is enough

    return errors


# ---------------------------------------------------------------------------
# Soft Validators — log warnings, don't block
# ---------------------------------------------------------------------------


def validate_prop_continuity(plan_dict: dict) -> list[str]:
    """Check prop handoff continuity within scenes.

    Shots are grouped by ``scene_index`` (default 1) and, within a scene,
    ordered by ``shot_id`` string comparison. Shots without
    ``blocking_metadata`` are skipped and do not reset tracking. Two
    situations are reported:

    * a prop held by two *different* characters in the same shot. One
      character holding the same prop in both hands (a two-handed grip) is
      not a conflict.
    * a prop held in an earlier blocked shot that, in the current shot, is in
      no character's hands and is not a key of ``blocking_metadata.prop_states``.
      Each disappearance is reported once. The prop then stops being tracked
      until a character holds it again, so later shots do not repeat the
      same warning.

    Args:
        plan_dict: Plan dictionary with a ``"shots"`` list of shot dicts.

    Returns:
        ``[WARN]`` strings; this validator never blocks write-back.
    """
    warnings: list[str] = []

    # Group shots by scene
    scenes: dict[int, list[dict]] = {}
    for shot in plan_dict.get("shots") or []:
        scenes.setdefault(shot.get("scene_index", 1), []).append(shot)

    for scene_index, scene_shots in scenes.items():
        # NOTE(review): plain string order — "S10" sorts before "S2" unless ids
        # are zero-padded; confirm the shot_id format.
        scene_shots.sort(key=lambda s: s.get("shot_id", ""))

        # prop_id -> (holder character_id, shot_id of the last shot it was held in)
        prop_last_state: dict[str, tuple[str, str]] = {}

        for shot in scene_shots:
            shot_id = shot.get("shot_id", "?")
            bm = shot.get("blocking_metadata")
            if bm is None:
                continue

            # Build current prop holders from character hand states
            current_holders: dict[str, str] = {}
            for c in bm.get("characters") or []:
                char_id = c.get("character_id", "?")
                for hand_key in ("dominant_hand", "secondary_hand"):
                    pid = (c.get(hand_key) or {}).get("prop_id")
                    if not pid:
                        continue
                    # Same character in both hands is a two-handed grip, not a conflict.
                    if pid in current_holders and current_holders[pid] != char_id:
                        warnings.append(
                            f"[WARN] {shot_id}: {pid} held by multiple "
                            f"characters: {current_holders[pid]} and {char_id}"
                        )
                    current_holders[pid] = char_id

            # Check for props that disappeared without narrative justification
            prop_states = bm.get("prop_states") or {}
            # Iterate over a snapshot because reported props are removed below.
            for prop_id, (last_holder, last_shot) in list(prop_last_state.items()):
                if prop_id in current_holders or prop_id in prop_states:
                    continue
                warnings.append(
                    f"[WARN] Scene {scene_index}/{shot_id}: prop '{prop_id}' "
                    f"was held by {last_holder} in {last_shot} but is now absent "
                    f"from both character hands and prop_states"
                )
                # Report each disappearance once, not again for every later shot.
                del prop_last_state[prop_id]

            # Update tracking
            for prop_id, holder in current_holders.items():
                prop_last_state[prop_id] = (holder, shot_id)

    return warnings


def validate_gaze_axis(plan_dict: dict) -> list[str]:
    """Surface shots whose blocking pass set the ``axis_violation`` flag.

    No gaze geometry is computed here; only the flag is read. Shots are
    visited scene by scene, with scenes in first-appearance order by
    ``scene_index`` (default 1) and shots within a scene in ``shot_id``
    string order. One ``[WARN]`` is emitted per shot whose
    ``blocking_metadata`` is truthy and has a truthy ``axis_violation``.
    """
    by_scene: dict[int, list[dict]] = {}
    for s in plan_dict.get("shots", []):
        by_scene.setdefault(s.get("scene_index", 1), []).append(s)

    flagged = []
    for group in by_scene.values():
        for s in sorted(group, key=lambda x: x.get("shot_id", "")):
            meta = s.get("blocking_metadata")
            if not (meta and meta.get("axis_violation")):
                continue
            flagged.append(
                f"[WARN] {s.get('shot_id', '?')}: axis_violation flag set — "
                f"spatial_data may contradict gaze axis"
            )
    return flagged


def validate_shot_type_detail(plan_dict: dict) -> list[str]:
    """Check that blocking detail matches shot type.

    For close shot sizes (``CU``, ``ECU``, ``MCU``, read from
    ``prompt_data.shot_type``, default ``"MS"``), any character whose
    blocking sets a truthy ``weight_bearing`` produces an ``[INFO]`` note
    that the detail will be ignored at that shot size. A missing or ``None``
    ``prompt_data`` falls back to the default shot type.

    Args:
        plan_dict: Plan dictionary with a ``"shots"`` list of shot dicts.

    Returns:
        ``[INFO]`` strings, one per (shot, character) with unused weight_bearing.
    """
    warnings: list[str] = []
    for shot in plan_dict.get("shots") or []:
        shot_id = shot.get("shot_id", "?")
        shot_type = (shot.get("prompt_data") or {}).get("shot_type", "MS")
        bm = shot.get("blocking_metadata")
        if bm is None:
            continue

        for c in bm.get("characters") or []:
            wb = c.get("weight_bearing")
            if shot_type in ("CU", "ECU", "MCU") and wb:
                warnings.append(
                    f"[INFO] {shot_id}: {shot_type} shot has weight_bearing='{wb}' — "
                    f"will be ignored at this shot size"
                )

    return warnings


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------

def run_all_validators(
    plan_dict: dict,
    bible: GlobalBible,
) -> list[str]:
    """Run every blocking validator over *plan_dict* and concatenate the results.

    If no shot carries ``blocking_metadata`` (the blocking pass has not run),
    nothing is validated and an empty list is returned. Otherwise the hard
    validators run first, followed by the soft ones, and their messages are
    returned in that order. Messages carry a ``[HARD]``, ``[WARN]`` or
    ``[INFO]`` prefix. ``validate_prop_id_validity`` is grouped with the hard
    validators but only emits ``[WARN]`` entries.
    """
    shots = plan_dict.get("shots", [])
    if all(s.get("blocking_metadata") is None for s in shots):
        return []

    hard_checks = (
        validate_schema_compliance,
        validate_character_parity,
        lambda plan: validate_prop_id_validity(plan, bible),
        validate_env_shots,
        validate_appearance_rejection,
    )
    soft_checks = (
        validate_prop_continuity,
        validate_gaze_axis,
        validate_shot_type_detail,
    )

    issues: list[str] = []
    for check in (*hard_checks, *soft_checks):
        issues.extend(check(plan_dict))
    return issues
