"""Canonical authored-prose validation for REC-72 prompt authoring."""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class Severity(Enum):
    """Severity level attached to a ``ValidationResult``.

    In this module BLOCK is used for objective failures and WARN for
    subjective creative signals; INFO is not emitted by the code visible here.
    """

    BLOCK = "BLOCK"
    WARN = "WARN"
    INFO = "INFO"


@dataclass
class ValidationResult:
    """One finding produced by the authored-prose verifier."""

    # How serious the finding is (BLOCK / WARN / INFO).
    severity: Severity
    # Identifier of the primitive the finding refers to; every call site in
    # this module fills it from ``_primitive_id``.
    pass_id: str
    # Machine-readable check name, e.g. "prose_verify_empty".
    check: str
    # Human-readable explanation of the finding.
    message: str


# Matches a literal "@Image" immediately followed by one or more digits.
_IMAGEN_RE = re.compile(r"@Image\d+")
# Matches a bracketed "[m:ss-m:ss]" timecode. Minutes are one or more digits,
# seconds are exactly two digits, and whitespace around the hyphen is allowed.
# Named groups: sm/ss = start minutes/seconds, em/es = end minutes/seconds.
_TIMECODE_RE = re.compile(
    r"\[(?P<sm>\d+):(?P<ss>\d{2})\s*-\s*(?P<em>\d+):(?P<es>\d{2})\]"
)
# Case-insensitive, whole-word match for camera / framing vocabulary
# (movement verbs, rig names, shot-size terms and their abbreviations).
_CAMERA_RE = re.compile(
    r"\b("
    r"camera|lens|push(?:es)?|pull(?:s)?|dolly|track(?:s|ing)?|handheld|"
    r"lock-?off|whip|rack(?:-|\s)?focus|crane|pan(?:s|ning)?|tilt(?:s|ing)?|"
    r"zoom|close-?up|wide|over-?the-?shoulder|ots|ecu|cu|ms|ws"
    r")\b",
    re.IGNORECASE,
)
_PERFORMANCE_RE = re.compile(
    r"\b("
    r"breath|jaw|eyes?|face|mouth|shoulders?|hands?|fingers?|body|posture|"
    r"flinch(?:es)?|trembl(?:es|ing)?|brac(?:es|ing)|hesitat(?:es|ing|ion)|"
    r"stare(?:s|ing)?|glance(?:s|ing)?|grimace(?:s)?|swallow(?:s|ing)?|"
    r"smile(?:s|ing)?|tear(?:s)?|shiver(?:s|ing)?|voice|expression"
    r")\b",
    re.IGNORECASE,
)


def verify_authored_prose(
    authored_text: str,
    primitive: Any,
    strategy: Any,
) -> list[ValidationResult]:
    """Validate raw author prose before deterministic name/ref binding.

    The gate blocks only objective failures. Subjective creative signals are
    emitted as WARN and must never stop dispatch.

    Args:
        authored_text: Raw prose produced by the author.
        primitive: Skeleton primitive the prose was written for.
        strategy: Authoring strategy that produced the prose.

    Returns:
        The list of findings. If the verifier itself fails, the exception is
        logged and a single BLOCK result with check
        ``prose_verify_internal_error`` is returned instead of raising.
    """

    try:
        return _verify_authored_prose(authored_text, primitive, strategy)
    except Exception as exc:
        # The primitive may be what broke the verifier (e.g. a raising
        # property), so resolving its id must not be allowed to escape from
        # this handler and break the never-raise contract.
        try:
            primitive_id = _primitive_id(primitive)
        except Exception:
            primitive_id = "authored_prose"
        logger.exception("prose_verify internal failure for %s", primitive_id)
        return [
            ValidationResult(
                Severity.BLOCK,
                primitive_id,
                "prose_verify_internal_error",
                f"Authored prose verifier failed without raising to caller: {exc}",
            )
        ]


def _verify_authored_prose(
    authored_text: str,
    primitive: Any,
    strategy: Any,
) -> list[ValidationResult]:
    """Run every authored-prose check and return the collected findings.

    BLOCK results are appended in this order: empty prose, ``@ImageN``
    literals, missing strategy inputs, char_ids absent from the prose,
    timecodes whose end is not after their start, timecode presence /
    segment count against the skeleton, and summed timecode duration against
    the target duration. Results from ``_subjective_warnings`` come last.

    Args:
        authored_text: Raw prose; any non-string value is treated as empty.
        primitive: Skeleton primitive, read through this module's accessor
            helpers.
        strategy: Authoring strategy; its ``name`` is used in messages.

    Returns:
        All findings, possibly empty. Exceptions are not handled here; the
        public ``verify_authored_prose`` wrapper converts them to a BLOCK.
    """
    results: list[ValidationResult] = []
    prose = authored_text if isinstance(authored_text, str) else ""
    pass_id = _primitive_id(primitive)

    def block(check: str, message: str) -> None:
        # Every objective failure shares severity and pass_id.
        results.append(ValidationResult(Severity.BLOCK, pass_id, check, message))

    if not prose.strip():
        block("prose_verify_empty", "Authored prose is empty")

    leaked = _IMAGEN_RE.findall(prose)
    if leaked:
        block(
            "prose_verify_imagen_leak",
            "Authored prose contains @ImageN literal(s) "
            f"{sorted(set(leaked))} — the author must never emit ref tokens "
            "(they are bound deterministically later)",
        )

    for missing in _missing_strategy_inputs(primitive, strategy):
        block(
            "prose_verify_required_input",
            f"Strategy {getattr(strategy, 'name', strategy)!r} is missing "
            f"required primitive input {missing!r}",
        )

    for char_id in _char_ids(primitive):
        if not _contains_character_name(prose, char_id):
            block(
                "prose_verify_coverage_char",
                f"Authored prose is missing char_id {char_id!r} from the skeleton",
            )

    spans = _parse_timecodes(prose)
    valid_spans: list[tuple[float, float, str]] = []
    for span in spans:
        start_s, end_s, literal = span
        if end_s > start_s:
            valid_spans.append(span)
            continue
        block(
            "prose_verify_timecodes",
            f"Authored prose has invalid timecode {literal!r} "
            f"({start_s:.2f}s-{end_s:.2f}s)",
        )

    expected_beats = _expected_beat_count(primitive)
    if expected_beats is not None:
        # The "no timecodes at all" case must be tested first: zero spans
        # also differs from the expected count, so checking the count first
        # would make this more specific message unreachable.
        if expected_beats and not spans:
            block(
                "prose_verify_timecodes",
                "Authored prose has no [m:ss-m:ss] timecodes but the "
                f"skeleton has {expected_beats} segment(s)",
            )
        elif len(spans) != expected_beats:
            block(
                "prose_verify_segment_count",
                f"Authored prose has {len(spans)} timecoded segment(s) but "
                f"the skeleton has {expected_beats}",
            )

    target_duration = _target_duration(primitive)
    if target_duration is not None and valid_spans:
        authored_duration = sum(end_s - start_s for start_s, end_s, _ in valid_spans)
        # Allowed drift is 5% of the target, clamped to [0.5s, 1.0s].
        tolerance = max(0.5, min(1.0, target_duration * 0.05))
        if abs(authored_duration - target_duration) > tolerance:
            block(
                "prose_verify_duration",
                f"Authored prose timecodes sum to {authored_duration:.2f}s "
                f"but target duration is {target_duration:.2f}s",
            )

    results.extend(_subjective_warnings(prose, primitive, strategy, spans))
    return results


def _primitive_id(primitive: Any) -> str:
    return str(
        getattr(primitive, "shot_id", None)
        or getattr(primitive, "pass_id", None)
        or "authored_prose"
    )


def _missing_strategy_inputs(primitive: Any, strategy: Any) -> list[str]:
    required = list(getattr(strategy, "required_inputs", None) or [])
    if not required:
        return []
    try:
        from recoil.pipeline._lib.author_strategies import missing_required_inputs

        return missing_required_inputs(primitive, strategy)
    except Exception:
        missing: list[str] = []
        for path in required:
            if not _truthy_path(primitive, path):
                missing.append(path)
        return missing


def _truthy_path(primitive: Any, dotted_path: str) -> bool:
    value = primitive
    for part in dotted_path.split("."):
        if isinstance(value, dict):
            value = value.get(part)
        else:
            value = getattr(value, part, None)
        if value is None:
            return False
    if isinstance(value, (list, tuple, dict, set)):
        return bool(value)
    return value != ""


def _char_ids(primitive: Any) -> list[str]:
    value = getattr(primitive, "char_ids", None)
    if not value and isinstance(primitive, dict):
        value = primitive.get("char_ids")
    return [str(item) for item in (value or []) if str(item).strip()]


def _contains_character_name(prose: str, char_id: str) -> bool:
    normalized = re.sub(r"[_\-]+", " ", char_id).strip()
    variants = {
        char_id,
        char_id.lower(),
        char_id.upper(),
        char_id.title(),
        normalized,
        normalized.lower(),
        normalized.upper(),
        normalized.title(),
    }
    parts = [part for part in re.split(r"[_\-\s]+", char_id) if part]
    if parts:
        variants.add(parts[-1])
        variants.add(parts[-1].lower())
        variants.add(parts[-1].upper())
        variants.add(parts[-1].title())
    for prefix in ("char_", "character_"):
        lowered = char_id.lower()
        if lowered.startswith(prefix):
            tail = char_id[len(prefix) :]
            variants.update({tail, tail.lower(), tail.upper(), tail.title()})

    for variant in sorted(variants, key=len, reverse=True):
        if not variant:
            continue
        if re.search(rf"\b{re.escape(variant)}\b", prose, flags=re.IGNORECASE):
            return True
    return False


def _parse_timecodes(prose: str) -> list[tuple[float, float, str]]:
    """Extract every ``[m:ss-m:ss]`` timecode found in *prose*.

    Returns one ``(start_seconds, end_seconds, literal)`` tuple per match, in
    order of appearance, where *literal* is the matched bracketed text. No
    ordering or range validation is done here.
    """

    def to_seconds(minutes: str, seconds: str) -> float:
        return float(int(minutes) * 60 + int(seconds))

    return [
        (
            to_seconds(found.group("sm"), found.group("ss")),
            to_seconds(found.group("em"), found.group("es")),
            found.group(0),
        )
        for found in _TIMECODE_RE.finditer(prose)
    ]


def _expected_beat_count(primitive: Any) -> int | None:
    value = getattr(primitive, "timing_segments", None)
    if value is None and isinstance(primitive, dict):
        value = primitive.get("timing_segments")
    if not value:
        return None
    try:
        return len(value)
    except TypeError:
        return None


def _target_duration(primitive: Any) -> float | None:
    value = getattr(primitive, "target_editorial_duration_s", None)
    if value is None and isinstance(primitive, dict):
        value = primitive.get("target_editorial_duration_s")
    try:
        duration = float(value)
    except (TypeError, ValueError):
        return None
    return duration if duration > 0 else None


def _subjective_warnings(
    prose: str,
    primitive: Any,
    strategy: Any,
    spans: list[tuple[float, float, str]],
) -> list[ValidationResult]:
    strategy_name = str(getattr(strategy, "name", "") or "")
    if strategy_name != "directed_prose" or not prose.strip():
        return []

    pass_id = _primitive_id(primitive)
    beat_texts = _beat_texts(prose, spans)
    results: list[ValidationResult] = []
    for idx, beat in enumerate(beat_texts, start=1):
        if not _CAMERA_RE.search(beat):
            results.append(
                ValidationResult(
                    Severity.WARN,
                    pass_id,
                    "prose_verify_camera",
                    f"Beat {idx} may be missing motivated camera language",
                )
            )
        if not _PERFORMANCE_RE.search(beat):
            results.append(
                ValidationResult(
                    Severity.WARN,
                    pass_id,
                    "prose_verify_performance",
                    f"Beat {idx} may be missing performance/body behavior",
                )
            )
    return results


def _beat_texts(prose: str, spans: list[tuple[float, float, str]]) -> list[str]:
    if not spans:
        lines = [line.strip() for line in prose.splitlines() if line.strip()]
        return lines or [prose]

    matches = list(_TIMECODE_RE.finditer(prose))
    beats: list[str] = []
    for idx, match in enumerate(matches):
        start = match.end()
        end = matches[idx + 1].start() if idx + 1 < len(matches) else len(prose)
        beats.append(prose[start:end].strip())
    return beats or [prose]
