"""Tests for axis_propagation (REC-180): deterministic per-shot spatial materialization."""
from __future__ import annotations

import copy

from recoil.pipeline._lib.render_schema import (
    AxisAnchor,
    AxisTransition,
    AxisTransitionKind,
    CutRelation,
    EpisodePlan,
    SceneAxisPlan,
    ScreenDirection,
    ShotRecord,
)
from recoil.pipeline.orchestrator.axis_validation import shot_index_of
from recoil.pipeline.orchestrator.axis_propagation import (

    mirror,
    project_direction,
    propagate_axis,
)


def _shot(shot_id: str, scene_index: int) -> ShotRecord:
    """Build a minimal valid ShotRecord for *shot_id* in scene *scene_index*.

    ``shot_index`` is derived from the digits following ``_SH`` in *shot_id*.
    Leading zeros are stripped and any trailing suffix is ignored, so
    ``"EP001_SH07A"`` yields 7 and ``"EP001_SH00"`` yields 0. ``spatial_data``
    starts as an empty mapping; the tests expect ``propagate_axis`` to fill it in.

    Raises:
        ValueError: if *shot_id* contains no ``_SH<digits>`` segment (previously
            this surfaced as an opaque AttributeError on ``None.group``).
    """
    import re

    match = re.search(r"_SH0*(\d+)", shot_id)
    if match is None:
        raise ValueError(f"shot_id {shot_id!r} has no _SH<digits> segment")
    idx = int(match.group(1))
    return ShotRecord.model_validate(
        {
            "shot_id": shot_id,
            "shot_index": idx,
            "scene_index": scene_index,
            "source_text": "x",
            "routing_data": {
                "target_editorial_duration_s": 4, "has_dialogue": False,
                "camera_complexity": "static", "num_characters": 1, "is_env_only": False,
            },
            "prompt_data": {
                "shot_type": "CU", "camera_movement": "static",
                "lighting": {"dominant_source_index": 0, "sources": [
                    {"motivator": "m", "direction": "below", "quality": "soft", "color_temp": "cool"}]},
                "prompt_skeleton": {"subject_line": "a", "environment_line": "b",
                                    "action_line": "c", "motion_line": "d e", "emotion_line": "f"},
            },
            "spatial_data": {},
            "asset_data": {"location_id": "loc", "time_of_day": "interior",
                           "visual_mode": "reality", "characters": [], "props": []},
            "audio_data": {"dialogue": [], "ambient_sfx": "", "foley_action": ""},
        }
    )


def _plan(shots, axis_plans=None) -> EpisodePlan:
    """Wrap *shots* in a minimal EP001/demo EpisodePlan.

    *axis_plans* maps scene_index -> SceneAxisPlan; a falsy value becomes ``{}``.
    """
    plans = axis_plans or {}
    return EpisodePlan(
        episode_id="EP001",
        project="demo",
        total_shots=len(shots),
        shots=shots,
        axis_plans=plans,
    )


def _motion(direction="left-to-right") -> SceneAxisPlan:
    """Build a scene plan whose only content is a motion anchor facing *direction*."""
    anchor = AxisAnchor(kind="motion", reference_direction=direction)
    return SceneAxisPlan(initial_anchor=anchor)


def test_mirror():
    # input direction -> expected mirrored direction (CENTER maps to itself)
    expected = {
        ScreenDirection.LEFT_TO_RIGHT: ScreenDirection.RIGHT_TO_LEFT,
        ScreenDirection.RIGHT_TO_LEFT: ScreenDirection.LEFT_TO_RIGHT,
        ScreenDirection.CENTER: ScreenDirection.CENTER,
        ScreenDirection.TOWARD_CAMERA: ScreenDirection.AWAY_FROM_CAMERA,
    }
    for direction, mirrored in expected.items():
        assert mirror(direction) == mirrored


def test_project_direction():
    motion = AxisAnchor(kind="motion", reference_direction="left-to-right")
    # Side A keeps the anchor's direction; side B sees it mirrored.
    assert project_direction(motion, "A") == ScreenDirection.LEFT_TO_RIGHT
    assert project_direction(motion, "B") == ScreenDirection.RIGHT_TO_LEFT

    # A neutral anchor projects to CENTER.
    neutral = AxisAnchor(kind="neutral")
    assert project_direction(neutral, "A") == ScreenDirection.CENTER


def test_shot_index_parse():
    import pytest

    # A letter-suffixed id (SH07A) resolves to the same index as SH07.
    for shot_id in ("EP001_SH07", "EP001_SH07A"):
        assert shot_index_of(_shot(shot_id, 1)) == 7

    # Index 0 is rejected.
    with pytest.raises(ValueError):
        shot_index_of(_shot("EP001_SH00", 1))


def test_transition_cutrelation_parity():
    # Every AxisTransitionKind value must also be a CutRelation value. Looking it up raises
    # ValueError if missing; the assert additionally checks that the member found carries the
    # same value. The previous `CutRelation(v) == CutRelation(v)` compared a value to itself
    # and so asserted nothing.
    for kind in AxisTransitionKind:
        assert CutRelation(kind.value).value == kind.value


def test_propagate_basic():
    shots = [_shot(f"EP001_SH0{n}", 1) for n in range(1, 5)]
    propagate_axis(_plan(shots, {1: _motion("left-to-right")}))

    # With no transitions, every shot sits on side A in segment 0, facing the anchor direction.
    for shot in shots:
        sd = shot.spatial_data
        assert sd.camera_side == "A"
        assert sd.screen_direction == ScreenDirection.LEFT_TO_RIGHT
        assert sd.axis_segment_id == 0

    # The first shot opens the scene; the next continues consistently.
    first, second = shots[0], shots[1]
    assert first.spatial_data.cut_relation == CutRelation.SCENE_OPEN
    assert second.spatial_data.cut_relation == CutRelation.CONSISTENT


def test_intentional_jump():
    shots = [_shot(f"EP001_SH0{n}", 1) for n in range(1, 5)]
    jump = AxisTransition(before_shot_index=3, kind="intentional_jump", reason="power shift")
    scene_plan = SceneAxisPlan(
        initial_anchor=AxisAnchor(kind="motion", reference_direction="left-to-right"),
        transitions=[jump],
    )
    propagate_axis(_plan(shots, {1: scene_plan}))

    before, at_jump = shots[1].spatial_data, shots[2].spatial_data
    assert before.camera_side == "A"
    assert at_jump.camera_side == "B"
    assert at_jump.screen_direction == ScreenDirection.RIGHT_TO_LEFT
    assert at_jump.cut_relation == CutRelation.INTENTIONAL_JUMP
    assert at_jump.axis_transition_reason == "power shift"
    assert at_jump.axis_segment_id == 0  # a jump flips the side but keeps the segment


def test_reestablish():
    shots = [_shot(f"EP001_SH0{n}", 1) for n in range(1, 5)]
    corridor = AxisAnchor(kind="motion", reference_direction="right-to-left")
    reestablish = AxisTransition(
        before_shot_index=3,
        kind="re_establish",
        reason="turns to corridor",
        new_anchor=corridor,
    )
    scene_plan = SceneAxisPlan(
        initial_anchor=AxisAnchor(kind="dialogue", reference_direction="left-to-right"),
        transitions=[reestablish],
    )
    propagate_axis(_plan(shots, {1: scene_plan}))

    # SH03 opens segment 1 on side A, facing the new anchor's direction.
    sd = shots[2].spatial_data
    assert sd.axis_segment_id == 1
    assert sd.camera_side == "A"
    assert sd.screen_direction == ScreenDirection.RIGHT_TO_LEFT
    assert sd.axis_transition_reason == "turns to corridor"


def test_neutral_pivot_bumps_segment():
    shots = [_shot(f"EP001_SH0{n}", 1) for n in range(1, 5)]
    pivot = AxisTransition(before_shot_index=3, kind="neutral_pivot", reason="on-axis bridge")
    scene_plan = SceneAxisPlan(
        initial_anchor=AxisAnchor(kind="motion", reference_direction="left-to-right"),
        transitions=[pivot],
    )
    propagate_axis(_plan(shots, {1: scene_plan}))

    before, at_pivot = (shots[i].spatial_data for i in (1, 2))
    assert before.axis_segment_id == 0
    assert at_pivot.axis_segment_id == 1  # the pivot opens a new segment...
    assert at_pivot.camera_side == "A"  # ...which starts back on side A
    assert at_pivot.cut_relation == CutRelation.NEUTRAL_PIVOT
    assert at_pivot.screen_direction == ScreenDirection.CENTER  # the bridge shot is on-axis


def test_jump_then_pivot_emerges_on_a():
    shots = [_shot(f"EP001_SH0{n}", 1) for n in range(1, 6)]
    transitions = [
        AxisTransition(before_shot_index=2, kind="intentional_jump", reason="cross"),
        AxisTransition(before_shot_index=4, kind="neutral_pivot", reason="bridge"),
    ]
    scene_plan = SceneAxisPlan(
        initial_anchor=AxisAnchor(kind="motion", reference_direction="left-to-right"),
        transitions=transitions,
    )
    propagate_axis(_plan(shots, {1: scene_plan}))

    sh02, sh04, sh05 = (shots[i].spatial_data for i in (1, 3, 4))
    assert sh02.camera_side == "B"  # flipped by the jump at SH02
    assert sh04.camera_side == "A"  # the pivot at SH04 resets to side A
    assert sh04.screen_direction == ScreenDirection.CENTER
    assert sh05.camera_side == "A"  # the following shot stays on the fresh side


def test_authoritative_shot_index_used():
    # The shot_index field (not the shot_id digits) must drive transition matching. The second
    # shot's id digits ("09") deliberately disagree with its shot_index (2). If matching used
    # the id, the transition at 2 would hit no shot, and s2 would not land on side B.
    # model_copy(update=...) sets the field without revalidating the record.
    s1 = _shot("EP001_SH01", 1)
    s2 = _shot("EP001_SH09", 1).model_copy(update={"shot_index": 2})
    assert s1.shot_index == 1 and s2.shot_index == 2
    plan = _plan([s1, s2], {1: SceneAxisPlan(
        initial_anchor=AxisAnchor(kind="motion", reference_direction="left-to-right"),
        transitions=[AxisTransition(before_shot_index=2, kind="intentional_jump", reason="x")],
    )})
    propagate_axis(plan)
    assert s1.spatial_data.camera_side == "A"
    assert s2.spatial_data.camera_side == "B"  # matched via shot_index=2, not the id's "09"


def test_neutral_fallback_missing_plan():
    shots = [_shot(f"EP001_SH0{n}", 1) for n in range(1, 4)]
    # Scene 1 has no entry in axis_plans at all.
    propagate_axis(_plan(shots, {}))

    # Every shot falls back to side A with a neutral (CENTER) direction.
    for shot in shots:
        sd = shot.spatial_data
        assert sd.camera_side == "A"
        assert sd.screen_direction == ScreenDirection.CENTER


def test_invalid_present_plan_neutral_fallback():
    # A plan that exists but is invalid (its transition points at a shot outside the scene)
    # must fall back to neutral: it may neither raise nor materialize wrong data.
    shots = [_shot(f"EP001_SH0{n}", 1) for n in range(1, 4)]
    out_of_scene = AxisTransition(before_shot_index=99, kind="intentional_jump", reason="x")
    bad = SceneAxisPlan(
        initial_anchor=AxisAnchor(kind="motion", reference_direction="left-to-right"),
        transitions=[out_of_scene],
    )
    plan = _plan(shots, {1: bad})
    propagate_axis(plan)  # must not raise

    for shot in shots:
        sd = shot.spatial_data
        assert sd.camera_side == "A"
        assert sd.screen_direction == ScreenDirection.CENTER  # neutral fallback
    # Provenance is reconciled: the rejected plan is removed from the episode.
    assert 1 not in plan.axis_plans


def test_noncontiguous_scene_neutral_fallback():
    # Intercut order scene 1 -> 2 -> 1 makes scene 1 non-contiguous. Scene 1 must fall back to
    # neutral instead of being materialized from its (otherwise valid) motion plan.
    shots = [_shot("EP001_SH01", 1), _shot("EP001_SH02", 2), _shot("EP001_SH03", 1)]
    plan = _plan(shots, {1: _motion("left-to-right"), 2: _motion("right-to-left")})
    propagate_axis(plan)

    first_scene1, scene2, second_scene1 = shots
    assert first_scene1.spatial_data.screen_direction == ScreenDirection.CENTER
    assert second_scene1.spatial_data.screen_direction == ScreenDirection.CENTER
    assert scene2.spatial_data.screen_direction == ScreenDirection.RIGHT_TO_LEFT  # unaffected
    # Only the dropped scene's plan is removed.
    assert 1 not in plan.axis_plans and 2 in plan.axis_plans


def test_duplicate_shot_index_neutral_fallback():
    # SH07 and SH07A both resolve to shot_index 7 within one scene. That ambiguity must force
    # the scene onto the neutral fallback.
    shots = [_shot(sid, 1) for sid in ("EP001_SH07", "EP001_SH07A")]
    propagate_axis(_plan(shots, {1: _motion("left-to-right")}))

    directions = [s.spatial_data.screen_direction for s in shots]
    assert directions == [ScreenDirection.CENTER] * len(shots)


def test_camera_side_literal():
    shots = [_shot(f"EP001_SH0{n}", 1) for n in range(1, 5)]
    scene_plan = SceneAxisPlan(
        initial_anchor=AxisAnchor(kind="motion", reference_direction="left-to-right"),
        transitions=[AxisTransition(before_shot_index=2, kind="intentional_jump", reason="x")],
    )
    propagate_axis(_plan(shots, {1: scene_plan}))

    # Even across a jump, camera_side only ever takes one of the two literal labels.
    sides = {s.spatial_data.camera_side for s in shots}
    assert sides <= {"A", "B"}


def test_reproducibility():
    shot_ids = [f"EP001_SH0{n}" for n in range(1, 5)]
    axis_plans = {1: SceneAxisPlan(
        initial_anchor=AxisAnchor(kind="motion", reference_direction="left-to-right"),
        transitions=[AxisTransition(before_shot_index=3, kind="intentional_jump", reason="x")],
    )}
    # Two independent plans built from identical inputs (fresh shots, deep-copied axis plans).
    plans = [
        _plan([_shot(sid, 1) for sid in shot_ids], copy.deepcopy(axis_plans))
        for _ in range(2)
    ]
    for p in plans:
        propagate_axis(p)

    # Propagation must serialize to byte-identical spatial data both times.
    first, second = ([s.spatial_data.model_dump_json() for s in p.shots] for p in plans)
    assert first == second


def test_e2e_varied():
    # Two back-to-back scenes with opposite motion anchors. Each scene must materialize its own
    # anchor direction. The former `len(dirs) >= 2` check would also pass if the scenes were
    # swapped or mixed, so each scene is now pinned to its expected direction.
    scene1 = [_shot(f"EP001_SH0{i}", 1) for i in range(1, 4)]
    scene2 = [_shot(f"EP001_SH0{i}", 2) for i in range(4, 7)]
    plan = _plan(scene1 + scene2, {
        1: _motion("left-to-right"),
        2: _motion("right-to-left"),
    })
    propagate_axis(plan)
    assert {s.spatial_data.screen_direction for s in scene1} == {ScreenDirection.LEFT_TO_RIGHT}
    assert {s.spatial_data.screen_direction for s in scene2} == {ScreenDirection.RIGHT_TO_LEFT}
