"""Tests for axis_validation (REC-180): validate / sanitize / degenerate_variation."""
from __future__ import annotations

from types import SimpleNamespace

from recoil.pipeline._lib.render_schema import (
    CreativeEpisodeOutput,
    ScreenDirection,
)
from recoil.pipeline.orchestrator.axis_validation import (
    degenerate_variation,
    sanitize_axis_plans,
    validate_axis_plans,
)


def _creative(scene_of: dict[int, int], axis_plans: dict) -> CreativeEpisodeOutput:
    """Build a minimal, schema-validated ``CreativeEpisodeOutput`` for a test.

    Only the keys of *scene_of* (the shot indices) are used here; the
    shot -> scene assignment is supplied to the validator separately via ``_ct``.
    Every shot receives the same placeholder prompt skeleton, shot type and
    duration. *axis_plans* is passed through verbatim.
    """

    def shot_payload(shot_index: int) -> dict:
        # a fresh skeleton dict per shot, so no input dict is shared between shots
        return {
            "shot_index": shot_index,
            "prompt_skeleton": {
                "subject_line": "a",
                "environment_line": "b",
                "action_line": "c",
                "motion_line": "d e",
                "emotion_line": "f",
            },
            "shot_type": "CU",
            "target_editorial_duration_s": 4,
        }

    shot_list = [shot_payload(idx) for idx in sorted(scene_of)]
    payload = {
        "episode_id": "EP001",
        "total_shots": len(shot_list),
        "shots": shot_list,
        "axis_plans": axis_plans,
    }
    return CreativeEpisodeOutput.model_validate(payload)


def _ct(scene_of: dict[int, int]):
    return SimpleNamespace(shots=[SimpleNamespace(shot_index=i, scene_index=sc)
                                  for i, sc in sorted(scene_of.items())])


def _motion(direction="left-to-right"):
    return {"initial_anchor": {"kind": "motion", "reference_direction": direction}}


def test_missing_scene_plan_is_not_error():
    """A scene with no authored plan is recoverable (neutral fallback), not a retry error."""
    # shots 1-2 belong to scene 1 and shot 3 to scene 2; only scene 1 gets a plan
    shot_to_scene = {1: 1, 2: 1, 3: 2}
    creative = _creative(shot_to_scene, {1: _motion()})
    problems = validate_axis_plans(creative, _ct(shot_to_scene))
    assert problems == [], problems


def test_extra_scene_key():
    """A plan keyed by a scene that owns no shots is reported as a nonexistent scene."""
    shot_to_scene = {1: 1}
    plans = {1: _motion(), 9: _motion()}  # scene 9 has no shots
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("nonexistent scene" in msg for msg in problems), problems


def test_bad_before_shot_index():
    """A transition whose before_shot_index belongs to a different scene is rejected."""
    shot_to_scene = {1: 1, 2: 2}
    # scene 1 contains only shot 1; shot 2 lives in scene 2
    foreign_jump = {"before_shot_index": 2, "kind": "intentional_jump", "reason": "x"}
    plans = {
        1: {**_motion("left-to-right"), "transitions": [foreign_jump]},
        2: _motion(),
    }
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("not a shot in this scene" in msg for msg in problems), problems


def test_duplicate_transition():
    """Two transitions targeting the same before_shot_index are flagged as duplicates."""
    shot_to_scene = {1: 1, 2: 1, 3: 1}
    repeated_jumps = [
        {"before_shot_index": 2, "kind": "intentional_jump", "reason": why}
        for why in ("x", "y")
    ]
    plans = {1: {**_motion("left-to-right"), "transitions": repeated_jumps}}
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("duplicate before_shot_index" in msg for msg in problems), problems


def test_transition_on_first_shot_rejected():
    """A transition may not sit in front of the first shot of its scene."""
    shot_to_scene = {1: 1, 2: 1}  # scene 1 = shots 1, 2; its first shot is 1
    opening_jump = {"before_shot_index": 1, "kind": "intentional_jump", "reason": "x"}
    plans = {1: {**_motion("left-to-right"), "transitions": [opening_jump]}}
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("first shot" in msg for msg in problems), problems


def test_out_of_range_before_shot_index_caught_semantically():
    """before_shot_index=0 must parse, then be caught by the semantic validator.

    The schema has no ``ge=`` constraint, so the out-of-range index survives
    parsing. The degrade path can then sanitize it instead of Pydantic
    aborting the whole output.
    """
    shot_to_scene = {1: 1, 2: 1}
    zero_jump = {"before_shot_index": 0, "kind": "intentional_jump", "reason": "x"}
    plans = {1: {**_motion("left-to-right"), "transitions": [zero_jump]}}
    parsed = _creative(shot_to_scene, plans)  # must not raise
    problems = validate_axis_plans(parsed, _ct(shot_to_scene))
    assert any("not a shot in this scene" in msg for msg in problems), problems


def test_reestablish_requires_anchor():
    """A re_establish transition that omits new_anchor is rejected."""
    shot_to_scene = {1: 1, 2: 1}
    anchorless = {"before_shot_index": 2, "kind": "re_establish", "reason": "x"}  # no new_anchor
    plans = {1: {**_motion("left-to-right"), "transitions": [anchorless]}}
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("re_establish requires new_anchor" in msg for msg in problems), problems


def test_empty_reason_rejected():
    """A transition must carry a non-empty reason."""
    shot_to_scene = {1: 1, 2: 1}
    unexplained = {"before_shot_index": 2, "kind": "intentional_jump", "reason": ""}
    plans = {1: {**_motion("left-to-right"), "transitions": [unexplained]}}
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("empty reason" in msg for msg in problems), problems


def test_non_neutral_center_rejected():
    """A motion (non-neutral) anchor pointing 'center' is rejected as not lateral."""
    shot_to_scene = {1: 1}
    plans = {1: _motion("center")}
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("must be lateral" in msg for msg in problems), problems


def test_non_neutral_toward_camera_rejected():
    """A motion (non-neutral) anchor pointing 'toward-camera' is rejected as not lateral."""
    shot_to_scene = {1: 1}
    plans = {1: _motion("toward-camera")}
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("must be lateral" in msg for msg in problems), problems


def test_neutral_anchor_toward_camera_ok():
    """A neutral anchor with an on-axis direction ('toward-camera') is valid."""
    shot_to_scene = {1: 1}
    neutral_anchor = {"kind": "neutral", "reference_direction": "toward-camera"}
    plans = {1: {"initial_anchor": neutral_anchor}}
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert problems == [], problems


def test_neutral_anchor_lateral_rejected():
    """A neutral anchor with a lateral direction is rejected; it must be on-axis."""
    shot_to_scene = {1: 1}
    neutral_anchor = {"kind": "neutral", "reference_direction": "left-to-right"}
    plans = {1: {"initial_anchor": neutral_anchor}}
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert any("must be on-axis" in msg for msg in problems), problems


def test_valid_plan_passes():
    """A well-formed two-scene plan with a justified jump yields no errors."""
    shot_to_scene = {1: 1, 2: 1, 3: 2}
    power_shift = {"before_shot_index": 2, "kind": "intentional_jump", "reason": "power shift"}
    plans = {
        1: {**_motion("left-to-right"), "transitions": [power_shift]},
        2: _motion("right-to-left"),
    }
    problems = validate_axis_plans(_creative(shot_to_scene, plans), _ct(shot_to_scene))
    assert problems == [], problems


def test_sanitize_drops_invalid():
    """sanitize_axis_plans drops invalid scene plans from the creative and reports their keys."""
    shot_to_scene = {1: 1, 2: 2}
    plans = {
        1: _motion("left-to-right"),  # valid
        2: _motion("center"),  # inert direction for a motion anchor -> invalid
    }
    creative = _creative(shot_to_scene, plans)
    removed = sanitize_axis_plans(creative, _ct(shot_to_scene))
    assert removed == [2]
    assert 1 in creative.axis_plans
    assert 2 not in creative.axis_plans


def _vshot(scene_index, direction):
    return SimpleNamespace(scene_index=scene_index,
                           spatial_data=SimpleNamespace(screen_direction=direction))


def test_degenerate_variation_per_scene():
    """One authored scene of four all-CENTER shots is flagged for per-scene monotony."""
    stub_shots = [_vshot(1, ScreenDirection.CENTER) for _ in range(4)]
    plan = SimpleNamespace(shots=stub_shots, axis_plans={1: object()})
    flagged = degenerate_variation(plan)
    assert len(flagged) == 1 and "scene 1" in flagged[0]
    # changing a single shot's direction breaks the monotony -> no warning
    stub_shots[0].spatial_data.screen_direction = ScreenDirection.LEFT_TO_RIGHT
    assert degenerate_variation(plan) == []


def test_degenerate_variation_unauthored_scene_not_warned():
    """A four-shot scene with no authored plan is intentionally neutral, so it is not warned."""
    stub_shots = [_vshot(1, ScreenDirection.CENTER) for _ in range(4)]
    plan = SimpleNamespace(shots=stub_shots, axis_plans={})
    flagged = degenerate_variation(plan)
    assert flagged == [], flagged


def test_degenerate_variation_short_scene_not_flagged():
    """An authored scene of only three identical shots is too short to be flagged."""
    stub_shots = [_vshot(1, ScreenDirection.CENTER) for _ in range(3)]
    plan = SimpleNamespace(shots=stub_shots, axis_plans={1: object()})
    flagged = degenerate_variation(plan)
    assert flagged == [], flagged
