"""
keyframe_rewrite_critic.py — IP3: Keyframe Rewrite Critic.

Runs AFTER Flash text rewrite in keyframe_context.build_smart_prompt().
Five deterministic dimensions — zero LLM calls, <1ms per shot.

Dimensions:
  SUBJECT_PRIMACY   — Character identity in first 30 words (SOFT)
  ENVIRONMENT_LOCK  — Key environment terms from bible survive rewrite (SOFT)
  SPATIAL_SURVIVAL  — Spatial block terms survive rewrite (SOFT)
  ARCHETYPE_DETOX   — No archetype trigger words (HARD)
  VFX_BLACKLIST     — No face-artifact-causing terms (HARD)
"""

import logging
import re
from typing import Any

from recoil.core.critic import CriticLoop, Dimension, Severity

from . import FailureMode

logger = logging.getLogger(__name__)

# ── Archetype trigger patterns (from prompt_engine._ARCHETYPE_SCRUBS) ──
# Each entry is (compiled case-insensitive whole-word regex, replacement text).
# Order matters: the specific "tactical <noun>" phrases come before the bare
# "tactical" entry, so sequential substitution applies the specific
# replacement first.
_ARCHETYPE_PATTERNS = [
    (re.compile(source, re.I), replacement)
    for source, replacement in (
        (r"\btactical pants\b", "cargo pants"),
        (r"\btactical assessment\b", "careful assessment"),
        (r"\btactical gloves?\b", "gloves"),
        (r"\btactical\b", "practical"),
        (r"\boperator\b", "professional"),
        (r"\bspecial forces\b", "trained"),
        (r"\bmilitary-grade\b", "industrial"),
        (r"\bcombat\s+ready\b", "prepared"),
        (r"\bwarrior\b", "fighter"),
        (r"\bcommando\b", "operative"),
    )
]

# ── VFX blacklist (terms that cause face artifacts in NBP) ──
# Case-insensitive, whole-word match on any one of the alternatives below.
# Two-word terms tolerate any amount of whitespace (including none) between
# the words, e.g. "faceswap" and "face swap" both match.
_VFX_BLACKLIST = re.compile(
    r"\b(?:"
    + "|".join(
        (
            r"morphing",
            r"dissolve",
            r"face\s*swap",
            r"double\s*exposure",
            r"glitch(?:ing)?",
            r"pixelat(?:ed|ion)",
            r"overlay",
            r"superimpos(?:ed|e)",
            r"split\s*screen",
            r"composite",
            r"blend(?:ing)?",
            r"crossfade",
            r"face\s*melt",
            r"warp(?:ing)?",
        )
    )
    + r")\b",
    flags=re.IGNORECASE,
)


class KeyframeRewriteCritic(CriticLoop):
    """IP3: Deterministic critic for Flash-rewritten keyframe prompts.

    Scores a rewritten prompt on five rule-based dimensions (``evaluate``)
    using only string and regex operations, and can mechanically repair the
    two HARD dimensions (``auto_fix``).

    Args:
        bible: Global bible dict. Keys read here: ``characters``
            (char_id -> dict with optional ``display_name``) and
            ``locations`` (location_id -> dict with ``visual_description``
            or ``description``).
        shot: Shot plan dict. Keys read here: ``asset_data``
            (``characters``, ``location_id``) and
            ``prompt_data["spatial_data"]`` (``camera_side``,
            ``subject_position``).
        max_attempts: Default 2 (deterministic auto-fix, one retry).
        experience_pool_dir: Path for JSONL logging (forwarded to CriticLoop).
        shot_id: Shot identifier for logging (forwarded to CriticLoop).
    """

    # Filler words ignored when extracting environment keywords from the bible.
    _ENV_STOPWORDS = frozenset({
        "this", "that", "with", "from", "into", "have", "been", "were",
        "they", "their", "there", "some", "most", "very", "each", "when",
        "where", "which", "about", "these", "those", "being",
    })
    # Minimum fraction of environment keywords that must survive the rewrite.
    _ENV_MIN_OVERLAP = 0.2

    def __init__(
        self,
        bible: dict,
        shot: dict,
        max_attempts: int = 2,
        experience_pool_dir=None,
        shot_id: str = "",
    ):
        super().__init__(
            name="keyframe_rewrite",
            max_attempts=max_attempts,
            experience_pool_dir=experience_pool_dir,
            shot_id=shot_id,
        )
        self.bible = bible
        self.shot = shot

    def evaluate(self, artifact: str, context: dict) -> list[Dimension]:
        """Evaluate rewritten keyframe prompt against 5 dimensions.

        Missing or ``None`` sections in ``bible`` / ``shot`` are treated as
        empty, which makes the corresponding dimension pass by default.

        Args:
            artifact: The rewritten prompt text from Flash.
            context: Unused.

        Returns:
            List of 5 Dimension results, in the order SUBJECT_PRIMACY,
            ENVIRONMENT_LOCK, SPATIAL_SURVIVAL, ARCHETYPE_DETOX,
            VFX_BLACKLIST.
        """
        dims: list[Dimension] = []
        prompt_lower = artifact.lower()
        first_30_words = " ".join(artifact.split()[:30]).lower()

        # ``or {}`` / ``or []`` also covers keys that are present but None.
        asset_data = self.shot.get("asset_data") or {}
        characters = asset_data.get("characters") or []
        bible_chars = self.bible.get("characters") or {}

        # ── SUBJECT_PRIMACY ───────────────────────────────────────
        # At least one character display_name or char_id must appear in the
        # first 30 words. A shot with no characters (ENV shot) passes.
        found_subject = not characters
        for char in characters:
            raw_id = char.get("char_id") if isinstance(char, dict) else char
            if raw_id is None:
                # No usable identifier on this entry; nothing to look for.
                continue
            char_id = str(raw_id)
            bible_char = bible_chars.get(char_id) or {}
            display_name = str(bible_char.get("display_name") or char_id)

            # Empty needles are skipped explicitly: "" is a substring of
            # every string and would otherwise count as a match.
            needles = (display_name.lower(), char_id.lower())
            if any(needle and needle in first_30_words for needle in needles):
                found_subject = True
                break

        dims.append(Dimension(
            name="SUBJECT_PRIMACY",
            severity=Severity.SOFT,
            passed=found_subject,
            message="" if found_subject else "No character identity found in first 30 words",
            failure_mode=None if found_subject else FailureMode.IDENTITY_DRIFT,
        ))

        # ── ENVIRONMENT_LOCK ──────────────────────────────────────
        # Key environment terms from the bible location must survive.
        location_id = asset_data.get("location_id", "")
        bible_locs = self.bible.get("locations") or {}
        bible_loc = bible_locs.get(location_id) or {}
        visual_desc = bible_loc.get("visual_description", "") or bible_loc.get("description", "")

        env_passed = True
        env_message = ""
        if visual_desc:
            # Significant words: 4+ characters, lowercased, minus stop words.
            key_words = {
                w.lower() for w in re.findall(r"\b\w{4,}\b", visual_desc)
            } - self._ENV_STOPWORDS
            if key_words:
                # Substring containment, so "wall" also counts when the
                # prompt says "walls".
                overlap = sum(1 for w in key_words if w in prompt_lower)
                overlap_ratio = overlap / len(key_words)
                env_passed = overlap_ratio >= self._ENV_MIN_OVERLAP
                if not env_passed:
                    env_message = f"Only {overlap}/{len(key_words)} environment keywords survive ({overlap_ratio:.0%})"

        dims.append(Dimension(
            name="ENVIRONMENT_LOCK",
            severity=Severity.SOFT,
            passed=env_passed,
            message=env_message,
            failure_mode=None if env_passed else FailureMode.BACKGROUND_CONTAMINATION,
        ))

        # ── SPATIAL_SURVIVAL ──────────────────────────────────────
        # camera_side / subject_position values must appear verbatim
        # (case-insensitively) in the rewritten prompt.
        prompt_data = self.shot.get("prompt_data") or {}
        spatial_data = prompt_data.get("spatial_data") or {}
        spatial_terms = [
            str(spatial_data[key]).lower()
            for key in ("camera_side", "subject_position")
            if spatial_data.get(key)
        ]
        missing = [term for term in spatial_terms if term not in prompt_lower]
        spatial_passed = not missing
        spatial_message = ""
        if not spatial_passed:
            spatial_message = f"Spatial terms missing from rewrite: {', '.join(missing)}"

        dims.append(Dimension(
            name="SPATIAL_SURVIVAL",
            severity=Severity.SOFT,
            passed=spatial_passed,
            message=spatial_message,
            failure_mode=None if spatial_passed else FailureMode.COVERAGE_GEOMETRY_BROKEN,
        ))

        # ── ARCHETYPE_DETOX ───────────────────────────────────────
        # Patterns are tried in table order (specific phrases first). A match
        # that lies inside an already reported span is skipped, so
        # "tactical pants" is not reported a second time as "tactical".
        reported: dict[str, str] = {}
        claimed_spans: list[tuple[int, int]] = []
        for pattern, _replacement in _ARCHETYPE_PATTERNS:
            for match in pattern.finditer(artifact):
                start, end = match.span()
                if any(s <= start and end <= e for s, e in claimed_spans):
                    continue
                claimed_spans.append((start, end))
                # Case-insensitive de-dup; keeps first-seen spelling and order.
                reported.setdefault(match.group().lower(), match.group())
        found_archetypes = list(reported.values())

        arch_passed = not found_archetypes
        dims.append(Dimension(
            name="ARCHETYPE_DETOX",
            severity=Severity.HARD,
            passed=arch_passed,
            message="" if arch_passed else f"Archetype triggers: {', '.join(found_archetypes)}",
            failure_mode=None if arch_passed else FailureMode.UNKNOWN,
        ))

        # ── VFX_BLACKLIST ─────────────────────────────────────────
        # Sorted so the message is identical from run to run (joining a raw
        # set would depend on string hash ordering).
        vfx_terms = sorted({term.lower() for term in _VFX_BLACKLIST.findall(artifact)})
        vfx_passed = not vfx_terms
        dims.append(Dimension(
            name="VFX_BLACKLIST",
            severity=Severity.HARD,
            passed=vfx_passed,
            message="" if vfx_passed else f"VFX blacklist terms: {', '.join(vfx_terms)}",
            failure_mode=None if vfx_passed else FailureMode.ANATOMY_FACE_MERGE,
        ))

        return dims

    def auto_fix(self, artifact: str, failed_dims: list[Dimension], context: dict) -> str:
        """Auto-fix ARCHETYPE_DETOX and VFX_BLACKLIST violations.

        Archetype triggers are replaced with their neutral counterparts from
        ``_ARCHETYPE_PATTERNS``; VFX blacklist terms are deleted. Other
        dimensions are left untouched. Each handled dimension gets
        ``auto_fixed = True``.

        Args:
            artifact: Prompt text to repair.
            failed_dims: Dimensions that failed the last evaluation.
            context: Unused.

        Returns:
            The repaired prompt text.
        """
        text = artifact
        for dim in failed_dims:
            if dim.name == "ARCHETYPE_DETOX":
                for pattern, replacement in _ARCHETYPE_PATTERNS:
                    text = pattern.sub(replacement, text)
                dim.auto_fixed = True
            elif dim.name == "VFX_BLACKLIST":
                # Remove VFX blacklist terms entirely
                text = _VFX_BLACKLIST.sub("", text)
                # Collapse the gap a removal leaves behind. Only spaces/tabs
                # are touched so line and paragraph breaks are preserved.
                text = re.sub(r"[ \t]{2,}", " ", text)
                # Drop a space stranded before punctuation ("a , b" -> "a, b").
                # The lookahead leaves tokens such as " .5" alone.
                text = re.sub(r"[ \t]+([,.;:!?])(?=\s|$)", r"\1", text)
                text = text.strip()
                dim.auto_fixed = True
        return text


# ──────────────────────────────────────────────────────────────────────────
# Structured-output adapter (Phase 2.5 — wired to real CriticLoop)
# ──────────────────────────────────────────────────────────────────────────

def keyframe_rewrite_critic(
    image=None,
    reference=None,
    structured_output: bool = False,
    artifact: str | None = None,
    bible: dict | None = None,
    shot: dict | None = None,
    **kwargs,
):
    """Structured-output adapter for the keyframe rewrite critic.

    Runs the real KeyframeRewriteCritic class and translates its
    Dimension list to a FailureMode-typed dict for the feedback agent.

    Required for structured_output=True: artifact (prompt text), bible, shot.
    """
    if not structured_output:
        raise NotImplementedError(
            "Use KeyframeRewriteCritic class directly for non-structured calls."
        )

    if artifact is None or bible is None or shot is None:
        raise ValueError(
            "keyframe_rewrite_critic(structured_output=True) requires "
            "artifact (prompt text), bible, and shot."
        )

    critic = KeyframeRewriteCritic(
        bible=bible,
        shot=shot,
        shot_id=kwargs.get("shot_id", ""),
    )
    _, result = critic.run(artifact, context={})

    # Translate Dimension list → FailureMode (first hard failure wins)
    failure_mode = FailureMode.NONE.value
    for dim in result.hard_failures:
        if dim.name == "VFX_BLACKLIST":
            failure_mode = FailureMode.ANATOMY_FACE_MERGE.value
            break
        elif dim.name == "ARCHETYPE_DETOX":
            failure_mode = FailureMode.UNKNOWN.value
            break

    # PC-1 keep-bias: soft failures alone do NOT flip passed to False
    return {
        "passed": result.passed,
        "score": 1.0 if result.passed else 0.0,
        "failure_mode": failure_mode,
        "evidence": "; ".join(d.message for d in result.failed_dimensions if d.message),
        "dimensions": [
            {"name": d.name, "passed": d.passed, "severity": d.severity.value}
            for d in result.dimensions
        ],
    }
