"""
plan_pass_critic.py — IP1: Plan Pass Skeleton Critic.

Runs AFTER Stage 2 skeleton generation in ingest_pipeline.py.
Five deterministic dimensions — zero LLM calls.

Reuses existing codebase validators:
  - validate_sterility() from `recoil.pipeline._lib.jit_prompt`
  - auto_fix_tokens() from `recoil.pipeline._lib.jit_prompt`

Dimensions:
  STERILITY           — No character names in environment_line (HARD)
  SHOT_GRAMMAR        — Skeleton has required fields populated (SOFT)
  VERB_STRENGTH       — action_line uses specific verbs, not vague ones (SOFT)
  EMOTION_SPECIFICITY — emotion_line avoids generic labels (SOFT)
  SPATIAL_COHERENCE   — No unlicensed camera-side flip from the previous shot (SOFT)
"""

import logging
import re

from recoil.core.critic import CriticLoop, Dimension, Severity

from . import FailureMode

logger = logging.getLogger(__name__)

# Vague emotion labels that are too generic
_VAGUE_EMOTIONS = {
    "sad", "happy", "angry", "scared", "nervous", "worried", "excited",
}

# Weak/vague verbs in action_line
_WEAK_VERBS = re.compile(
    r"\b(?:is|are|was|were|goes?|moves?|walks?|comes?|gets?|does?|looks?|seems?|feels?)\b",
    re.IGNORECASE,
)

# Required skeleton fields
_REQUIRED_FIELDS = ["subject_line", "environment_line"]
_OPTIONAL_FIELDS = ["action_line", "emotion_line"]


class PlanPassCritic(CriticLoop):
    """IP1: Deterministic critic for Stage 2 shot skeletons.

    Scores one shot dict against five rule-based dimensions: STERILITY (HARD),
    SHOT_GRAMMAR, VERB_STRENGTH, EMOTION_SPECIFICITY and SPATIAL_COHERENCE
    (all SOFT). ``auto_fix`` only repairs STERILITY.

    Args:
        bible: Global bible dict. It is forwarded to the jit_prompt sterility
            validator and token auto-fixer.
        all_shots: All shots in the episode, used for the spatial coherence
            check. The entry just before the evaluated shot is treated as the
            previous shot, so the list is assumed to be in playback order.
        max_attempts: Default 2 (auto-fix sterility, one retry).
        experience_pool_dir: Path for JSONL logging.
        shot_id: Shot identifier for logging.
    """

    # Filler words allowed next to vague labels ("a sad", "angry and scared")
    # without making the emotion_line count as specific.
    _EMOTION_FILLER = frozenset({"a", "an", "the", "and", "of", "with"})

    def __init__(
        self,
        bible: dict,
        all_shots: list[dict] | None = None,
        max_attempts: int = 2,
        experience_pool_dir=None,
        shot_id: str = "",
    ):
        super().__init__(
            name="plan_pass",
            max_attempts=max_attempts,
            experience_pool_dir=experience_pool_dir,
            shot_id=shot_id,
        )
        self.bible = bible
        self.all_shots = all_shots or []

    # ── small helpers ─────────────────────────────────────────────────────

    @staticmethod
    def _as_dict(value) -> dict:
        """Return *value* if it is a dict, else an empty dict (tolerates None)."""
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _text(value) -> str:
        """Return *value* as a stripped string; ``None`` becomes ``""``."""
        if value is None:
            return ""
        return value.strip() if isinstance(value, str) else str(value).strip()

    def _previous_shot(self, shot_id) -> dict | None:
        """Return the entry preceding *shot_id* in ``all_shots``, or None.

        An empty shot_id is never looked up: it would match any shot whose id
        is also empty, which is not a meaningful "same shot" match.
        """
        if not shot_id or not self.all_shots:
            return None
        idx = next(
            (i for i, s in enumerate(self.all_shots) if s.get("shot_id") == shot_id),
            -1,
        )
        return self.all_shots[idx - 1] if idx > 0 else None

    # ── entry point ───────────────────────────────────────────────────────

    def evaluate(self, artifact: dict, context: dict) -> list[Dimension]:
        """Evaluate a shot dict against 5 dimensions.

        Args:
            artifact: Shot dict (from ep_NNN_plan.json). The skeleton is read
                from ``artifact["prompt_data"]["prompt_skeleton"]``. Missing or
                None levels are treated as empty.
            context: Unused.

        Returns:
            List of 5 Dimension results, in the order STERILITY, SHOT_GRAMMAR,
            VERB_STRENGTH, EMOTION_SPECIFICITY, SPATIAL_COHERENCE.
        """
        prompt_data = self._as_dict(artifact.get("prompt_data"))
        skeleton = self._as_dict(prompt_data.get("prompt_skeleton"))
        return [
            self._check_sterility(artifact),
            self._check_shot_grammar(skeleton),
            self._check_verb_strength(skeleton),
            self._check_emotion_specificity(skeleton),
            self._check_spatial_coherence(artifact),
        ]

    # ── per-dimension checks ──────────────────────────────────────────────

    def _check_sterility(self, artifact: dict) -> Dimension:
        """STERILITY (HARD): delegate to jit_prompt.validate_sterility.

        This is best-effort. If the validator cannot be imported, or raises,
        the dimension passes. A raising validator is logged at WARNING with a
        traceback so that it is not hidden.
        """
        try:
            from recoil.pipeline._lib.jit_prompt import validate_sterility
            violations = list(validate_sterility(artifact, self.bible) or [])
        except ImportError as e:
            logger.debug("validate_sterility unavailable: %s", e)
            violations = []
        except Exception:
            logger.warning(
                "validate_sterility raised; treating STERILITY as passed",
                exc_info=True,
            )
            violations = []

        passed = not violations
        # str() guards against validators that return non-string violation objects.
        joined = "; ".join(str(v) for v in violations)
        return Dimension(
            name="STERILITY",
            severity=Severity.HARD,
            passed=passed,
            message="" if passed else f"Sterility violations: {joined}",
            failure_mode=None if passed else FailureMode.UNKNOWN,
        )

    def _check_shot_grammar(self, skeleton: dict) -> Dimension:
        """SHOT_GRAMMAR (SOFT): required skeleton fields must be non-blank.

        Blank optional fields do not fail the dimension. They are reported in
        the message of an otherwise passing result.
        """
        missing_required = [f for f in _REQUIRED_FIELDS if not self._text(skeleton.get(f))]
        empty_optional = [f for f in _OPTIONAL_FIELDS if not self._text(skeleton.get(f))]
        passed = not missing_required
        if missing_required:
            message = f"Missing required fields: {', '.join(missing_required)}"
        elif empty_optional:
            message = f"Empty optional fields: {', '.join(empty_optional)}"
        else:
            message = ""
        return Dimension(
            name="SHOT_GRAMMAR",
            severity=Severity.SOFT,
            passed=passed,
            message=message,
            failure_mode=None if passed else FailureMode.UNKNOWN,
        )

    def _check_verb_strength(self, skeleton: dict) -> Dimension:
        """VERB_STRENGTH (SOFT): under 30% of words may be weak verbs.

        If action_line is blank, subject_line is checked instead.
        """
        line = self._text(skeleton.get("action_line")) or self._text(skeleton.get("subject_line"))
        weak_matches = _WEAK_VERBS.findall(line)
        weak_ratio = len(weak_matches) / max(len(line.split()), 1)
        passed = weak_ratio < 0.3
        message = ""
        if not passed:
            # Sort so the message does not depend on set iteration order, which
            # varies between runs with string hash randomisation.
            weak = ", ".join(sorted({v.lower() for v in weak_matches}))
            message = f"Weak verb ratio {weak_ratio:.0%}: {weak}"
        return Dimension(
            name="VERB_STRENGTH",
            severity=Severity.SOFT,
            passed=passed,
            message=message,
            failure_mode=None if passed else FailureMode.UNKNOWN,
        )

    def _check_emotion_specificity(self, skeleton: dict) -> Dimension:
        """EMOTION_SPECIFICITY (SOFT): fail if emotion_line is only a vague label.

        Words are extracted with ``\\w+``, so punctuation ("Sad.", "sad, angry")
        does not hide a vague label. A blank emotion_line passes.
        """
        emotion_line = self._text(skeleton.get("emotion_line"))
        words = set(re.findall(r"\w+", emotion_line.lower()))
        passed = not (words and words <= (_VAGUE_EMOTIONS | self._EMOTION_FILLER))
        return Dimension(
            name="EMOTION_SPECIFICITY",
            severity=Severity.SOFT,
            passed=passed,
            message="" if passed else f"Vague emotion label: '{emotion_line}'",
            failure_mode=None if passed else FailureMode.UNKNOWN,
        )

    def _check_spatial_coherence(self, artifact: dict) -> Dimension:
        """SPATIAL_COHERENCE (SOFT): flag unlicensed 180-degree crossings.

        Uses the derived axis fields (REC-180/REC-181) from the SIBLING
        ``spatial_data`` (not ``prompt_data.spatial_data``, the old dead path)
        of this shot and the previous shot. A crossing is UNLICENSED when:
          - both shots carry ``cut_relation`` and ``axis_segment_id``;
          - both camera sides are "A"/"B";
          - the shots share the same scene and axis segment;
          - the side flips;
          - this shot's ``cut_relation`` is "consistent".
        A flip with a licensing cut_relation (intentional_jump / re_establish /
        neutral_pivot) is legal and passes.
        """
        shot_id = artifact.get("shot_id", "")
        spatial_data = self._as_dict(artifact.get("spatial_data"))
        camera_side = spatial_data.get("camera_side", "")
        passed = True
        message = ""

        prev_shot = self._previous_shot(shot_id) if camera_side else None
        if prev_shot is not None:
            prev_spatial = self._as_dict(prev_shot.get("spatial_data"))
            prev_side = prev_spatial.get("camera_side", "")
            prev_scene = prev_shot.get("scene_index", -1)
            this_scene = artifact.get("scene_index", -1)
            seg = spatial_data.get("axis_segment_id", 0)
            prev_seg = prev_spatial.get("axis_segment_id", 0)
            cut_relation = spatial_data.get("cut_relation", "consistent")
            # Legacy or partially-migrated artifacts (sibling spatial_data without
            # the derived fields) are skipped to avoid false-positive crossings.
            has_axis = (
                "cut_relation" in spatial_data and "axis_segment_id" in spatial_data
                and "cut_relation" in prev_spatial and "axis_segment_id" in prev_spatial
            )
            if (
                has_axis
                and prev_side in {"A", "B"}
                and camera_side in {"A", "B"}
                and prev_scene == this_scene
                and prev_seg == seg
                and camera_side != prev_side
                and cut_relation == "consistent"
            ):
                passed = False
                message = (
                    f"Unlicensed 180 crossing: {prev_side}->{camera_side} "
                    f"within segment {seg} ({shot_id})"
                )

        return Dimension(
            name="SPATIAL_COHERENCE",
            severity=Severity.SOFT,
            passed=passed,
            message=message,
            failure_mode=None if passed else FailureMode.COVERAGE_GEOMETRY_BROKEN,
        )

    # ── repair ────────────────────────────────────────────────────────────

    def auto_fix(self, artifact: dict, failed_dims: list[Dimension], context: dict) -> dict:
        """Auto-fix STERILITY violations using existing auto_fix_tokens.

        Other failed dimensions are left untouched. STERILITY dimensions are
        marked ``auto_fixed`` only if ``auto_fix_tokens`` ran without raising.
        On import failure or error, the artifact is returned unchanged.
        """
        sterility_dims = [d for d in failed_dims if d.name == "STERILITY"]
        if not sterility_dims:
            return artifact
        try:
            from recoil.pipeline._lib.jit_prompt import auto_fix_tokens
            artifact = auto_fix_tokens(artifact, self.bible)
        except ImportError as e:
            logger.debug("auto_fix_tokens unavailable: %s", e)
            return artifact
        except Exception:
            logger.warning("auto_fix_tokens raised; artifact left unchanged", exc_info=True)
            return artifact
        for dim in sterility_dims:
            dim.auto_fixed = True
        return artifact


# ──────────────────────────────────────────────────────────────────────────
# Structured-output adapter (Phase 2.5 — wired to real CriticLoop)
# ──────────────────────────────────────────────────────────────────────────

def plan_pass_critic(
    image=None,
    reference=None,
    structured_output: bool = False,
    shot: dict | None = None,
    bible: dict | None = None,
    all_shots: list | None = None,
    **kwargs,
):
    """Structured-output adapter for the plan pass critic.

    Runs the real PlanPassCritic class and translates its Dimension list to
    a FailureMode-typed dict. Plan-pass is script-side, so any failure
    returns UNKNOWN (image-side failure_modes don't apply).

    Required for structured_output=True: shot (dict), bible (dict).
    """
    if not structured_output:
        raise NotImplementedError(
            "Use PlanPassCritic class directly for non-structured calls."
        )

    if shot is None or bible is None:
        raise ValueError(
            "plan_pass_critic(structured_output=True) requires "
            "shot (dict) and bible (dict)."
        )

    critic = PlanPassCritic(
        bible=bible,
        all_shots=all_shots or [],
        shot_id=kwargs.get("shot_id", shot.get("shot_id", "")),
    )
    _, result = critic.run(shot, context={})

    failure_mode = (
        FailureMode.NONE.value if result.passed else FailureMode.UNKNOWN.value
    )
    return {
        "passed": result.passed,
        "score": 1.0 if result.passed else 0.0,
        "failure_mode": failure_mode,
        "evidence": "; ".join(d.message for d in result.failed_dimensions if d.message),
        "dimensions": [
            {"name": d.name, "passed": d.passed, "severity": d.severity.value}
            for d in result.dimensions
        ],
    }
