"""Validation for LLM-authored scene axis plans (REC-180).

`scene_plan_errors` + `noncontiguous_scenes` are the SHARED predicates used by
`validate_axis_plans` (retry-loop gate), `sanitize_axis_plans` (final-degrade drop), AND
`axis_propagation.propagate_axis` (neutral-fallback decision) — so all three agree on what
"invalid/unsupported" means (incl. duplicate effective shot indices and intercut scenes).
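`degenerate_variation` is a post-materialization advisory. None of these functions perform I/O, and
all are pure except `sanitize_axis_plans`, which deletes entries from `creative.axis_plans` in place.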
"""
from __future__ import annotations

import re

from recoil.pipeline._lib.render_schema import (
    AxisKind,
    AxisTransitionKind,
    ScreenDirection,
)

# Trailing shot token of a shot_id: "_SH", optional zero padding, the number (captured), then an
# optional upper-case letter suffix such as the "A" in SH07A, anchored at the end of the id.
_SHOT_INDEX_RE = re.compile(r"_SH0*(\d+)[A-Z]*$")


def shot_index_of(shot) -> int:
    """Return the 1-based shot index of `shot`.

    This is the ONE extractor that validation and propagation both use, so they can never
    disagree about which shot is which. A `shot_index` attribute >= 1 always wins. Only when that
    attribute is absent, falsy, or below 1 (legacy records) is `shot_id` parsed instead, with
    letter-suffixed ids collapsing onto their base number (EP001_SH07 -> 7, EP001_SH07A -> 7).

    Raises:
        ValueError: no usable `shot_index` and `shot_id` is missing/unparseable, or the parsed
            number is not positive (e.g. EP001_SH00).
    """
    declared = getattr(shot, "shot_index", 0) or 0
    if declared >= 1:
        return declared

    shot_id = getattr(shot, "shot_id", None)
    match = _SHOT_INDEX_RE.search(shot_id) if shot_id else None
    if not match:
        raise ValueError(f"un-indexable shot {shot_id!r} (no shot_index and unparseable id)")

    parsed = int(match.group(1))
    if parsed >= 1:
        return parsed
    raise ValueError(f"shot_id {shot_id!r} parsed to non-positive index {parsed}")


def noncontiguous_scenes(shots) -> set[int]:
    """Return the scene_index values that re-appear after a DIFFERENT scene (an intercut).

    `shots` must be in episode order. Each maximal run of consecutive shots sharing a
    `scene_index` counts as one appearance of that scene; any scene appearing more than once is
    reported. Current Stage-0 output never intercuts (scene_index is monotonic), but validation
    and propagation both call this so they always agree on which scenes are unsupported.

    Args:
        shots: iterable of objects with a `scene_index` attribute, consumed in a single pass.

    Returns:
        The set of intercut scene_index values (empty when every scene is contiguous).
    """
    from itertools import groupby  # local import keeps this helper self-contained

    seen: set[int] = set()
    bad: set[int] = set()
    # groupby yields one key per consecutive run, so no sentinel "previous scene" value is needed
    # (a None sentinel would collide with a real scene_index of None and skip that first run).
    for scene_index, _run in groupby(shots, key=lambda s: s.scene_index):
        if scene_index in seen:
            bad.add(scene_index)
        seen.add(scene_index)
    return bad


def _ordered_scene_indices(shots) -> dict[int, list[int]]:
    """Group shot indices by scene.

    Returns a plain dict mapping scene_index -> that scene's shot indices, in the order the shots
    appear in `shots`. Every index comes from `shot_index_of` (the shared extractor), and repeats
    are kept so the duplicate-effective-index check in `scene_plan_errors` can see them.
    Any ValueError from `shot_index_of` propagates unchanged.
    """
    by_scene: dict[int, list[int]] = {}
    for shot in shots:
        bucket = by_scene.get(shot.scene_index)
        if bucket is None:
            bucket = by_scene[shot.scene_index] = []
        bucket.append(shot_index_of(shot))
    return by_scene


def _anchor_errors(scene_plan) -> list[str]:
    """Direction-consistency errors for the scene's initial_anchor and each transition's new_anchor."""
    # A non-neutral anchor (a real line of action) must point LATERALLY; a neutral anchor must be
    # ON-AXIS. Enforcing both keeps the validator aligned with project_direction, which maps
    # lateral -> center under a neutral anchor, so neutral+lateral would pass validation yet
    # materialize contradictory data.
    lateral = {ScreenDirection.LEFT_TO_RIGHT, ScreenDirection.RIGHT_TO_LEFT}
    on_axis = {ScreenDirection.CENTER, ScreenDirection.TOWARD_CAMERA, ScreenDirection.AWAY_FROM_CAMERA}

    labelled = [("initial_anchor", scene_plan.initial_anchor)]
    for pos, transition in enumerate(scene_plan.transitions):
        if transition.new_anchor is not None:
            labelled.append((f"transition[{pos}].new_anchor", transition.new_anchor))

    problems: list[str] = []
    for label, anchor in labelled:
        direction = anchor.reference_direction
        neutral = anchor.kind == AxisKind.NEUTRAL
        if not neutral and direction not in lateral:
            problems.append(
                f"{label}: non-neutral anchor (kind={anchor.kind.value}) reference_direction="
                f"{direction.value!r} must be lateral (left-to-right/right-to-left)"
            )
        if neutral and direction not in on_axis:
            problems.append(
                f"{label}: neutral anchor reference_direction={direction.value!r} "
                f"must be on-axis (center/toward-camera/away-from-camera)"
            )
    return problems


def _transition_errors(scene_plan, scene_shots) -> list[str]:
    """Per-transition structural errors, checked against the set of this scene's shot indices."""
    opening_shot = min(scene_shots) if scene_shots else None
    claimed: set[int] = set()
    problems: list[str] = []
    for pos, transition in enumerate(scene_plan.transitions):
        target = transition.before_shot_index
        where = f"transition[{pos}] (before_shot_index={target})"
        if target not in scene_shots:
            problems.append(f"{where}: before_shot_index is not a shot in this scene")
        elif target == opening_shot:
            # The scene's first shot has no preceding cut to license a change; its state
            # belongs in initial_anchor instead.
            problems.append(f"{where}: a transition cannot apply to the scene's first shot "
                            f"(no prior cut to license — encode initial state in initial_anchor)")
        if target in claimed:
            problems.append(f"{where}: duplicate before_shot_index")
        claimed.add(target)
        if not (transition.reason or "").strip():
            problems.append(f"{where}: empty reason (required)")
        kind = transition.kind
        if kind == AxisTransitionKind.RE_ESTABLISH and transition.new_anchor is None:
            problems.append(f"{where}: re_establish requires new_anchor")
        if kind in (AxisTransitionKind.INTENTIONAL_JUMP, AxisTransitionKind.NEUTRAL_PIVOT) and transition.new_anchor is not None:
            problems.append(f"{where}: {kind.value} must not carry new_anchor")
    return problems


def scene_plan_errors(scene_plan, scene_shot_indices) -> list[str]:
    """List every structural problem with ONE scene's axis plan.

    `scene_shot_indices` is that scene's shot indices (any iterable, e.g. list or set). The
    result covers, in order:
      1. duplicate effective shot indices (e.g. SH07 and SH07A both -> 7);
      2. anchor direction consistency (initial_anchor, then each transition new_anchor);
      3. per-transition structure (membership, first-shot, duplicates, reason, new_anchor rules).
    validate, sanitize, and propagate all share this single predicate. An empty list means the
    plan is usable.
    """
    ordered = list(scene_shot_indices)
    unique = set(ordered)
    errs: list[str] = []
    if len(unique) < len(ordered):
        errs.append("scene has duplicate effective shot indices (axis transitions would be ambiguous)")
    errs += _anchor_errors(scene_plan)
    errs += _transition_errors(scene_plan, unique)
    return errs


def validate_axis_plans(creative, ct) -> list[str]:
    """Check creative.axis_plans against the camera-tested scenes in `ct.shots`.

    Returns human-readable errors, visiting plan keys in ascending scene order. Only these raise
    retry pressure: a key for a scene that does not exist, a non-contiguous (intercut) scene, or
    a present plan that `scene_plan_errors` rejects (which includes duplicate effective indices).
    A scene with NO authored plan is fine: propagate_axis falls back to neutral for it and
    degenerate_variation handles the monotony warning.
    """
    scene_lists = _ordered_scene_indices(ct.shots)
    intercut = noncontiguous_scenes(ct.shots)
    problems: list[str] = []
    for scene in sorted(creative.axis_plans):
        if scene not in scene_lists:
            problems.append(f"axis_plans has key {scene} for a nonexistent scene")
        elif scene in intercut:
            problems.append(f"scene {scene} is non-contiguous (intercut) — unsupported")
        else:
            found = scene_plan_errors(creative.axis_plans[scene], scene_lists[scene])
            problems += [f"scene {scene} {e}" for e in found]
    return problems


def sanitize_axis_plans(creative, ct) -> list[int]:
    """Final degrade: delete every unusable scene plan from creative.axis_plans IN PLACE.

    A plan is unusable under exactly the rules validate_axis_plans reports. Its scene must
    exist, must not be intercut, and must pass `scene_plan_errors`. Returns the deleted
    scene_index keys in ascending order.
    """
    scene_lists = _ordered_scene_indices(ct.shots)
    intercut = noncontiguous_scenes(ct.shots)
    plans = creative.axis_plans
    dropped: list[int] = []
    # Iterate over a snapshot of the keys because entries are deleted as we go.
    for scene in list(plans):
        usable = (
            scene in scene_lists
            and scene not in intercut
            and not scene_plan_errors(plans[scene], scene_lists[scene])
        )
        if not usable:
            del plans[scene]
            dropped.append(scene)
    dropped.sort()
    return dropped


def degenerate_variation(plan, *, min_shots: int = 4) -> list[str]:
    """Post-materialization advisory: flag authored scenes whose staging never varies.

    A warning is produced for each scene that meets both conditions:
      (a) it has an entry in `plan.axis_plans`;
      (b) it has at least `min_shots` shots, all sharing a single
          `spatial_data.screen_direction`.
    Scenes without an authored plan are skipped. They are intentionally neutral (the recoverable
    degrade path), not an LLM monotony failure.

    Args:
        plan: object exposing `shots` (each with `scene_index` and
            `spatial_data.screen_direction`) and, optionally, `axis_plans`.
        min_shots: smallest scene size worth checking; keyword-only, default 4.

    Returns:
        Warning strings ordered by scene_index; empty when no authored scene is monotonous.
    """
    by_scene: dict[int, list] = {}
    for s in plan.shots:
        by_scene.setdefault(s.scene_index, []).append(s)
    authored = getattr(plan, "axis_plans", {}) or {}
    warnings: list[str] = []
    for scene_index, shots in sorted(by_scene.items()):
        if scene_index not in authored or len(shots) < min_shots:
            continue
        dirs = {s.spatial_data.screen_direction for s in shots}
        if len(dirs) > 1:
            continue
        # Every group built above holds at least one shot, so `dirs` has exactly one member here.
        (only_dir,) = dirs
        warnings.append(
            f"scene {scene_index}: spatial staging not varied — all {len(shots)} "
            f"shots screen_direction={only_dir.value!r}"
        )
    return warnings
