"""Unit tests for cluster_shots_into_batches (CanonicalShot input)."""
from __future__ import annotations

import sys
from pathlib import Path

# Put the repository root (four directories above this file's directory) at the
# front of sys.path so the ``recoil`` imports below resolve. Skip the insert
# when the entry is already present so repeated imports don't stack duplicates.
_REPO_ROOT = Path(__file__).resolve().parents[4]
if not any(entry == str(_REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_REPO_ROOT))

from recoil.pipeline._lib.scene_clusterer import cluster_shots_into_batches  # noqa: E402
from recoil.pipeline._lib.plan_loader import CanonicalShot  # noqa: E402


def _shot(shot_id, scene, loc, stype="MS", chars=None, dur=3):
    """Build a minimal still-pipeline ``CanonicalShot`` for clusterer tests.

    Only the fields these tests vary are parameters: id, scene index,
    location, shot type, characters (copied into a fresh list; ``None``
    means no characters) and duration (coerced to float). Every other
    field is pinned to a fixed value.
    """
    character_list = list(chars or [])
    fixed_fields = dict(
        sequence_id=None,
        pipeline="still",
        previs_model=None,
        video_model=None,
        is_env_only=False,
        has_dialogue=False,
        aspect_ratio=None,
        raw={},
    )
    return CanonicalShot(
        shot_id=shot_id,
        scene_index=scene,
        location_id=loc,
        characters=character_list,
        shot_type=stype,
        duration_s=float(dur),
        **fixed_fields,
    )


def test_empty_input_returns_empty():
    """No shots in means no batches out."""
    result = cluster_shots_into_batches([])
    assert result == []


def test_single_location_groups_into_one_batch():
    """Four consecutive same-location MS shots form a single batch."""
    shots = [_shot(sid, idx, "L1") for idx, sid in enumerate("ABCD", start=1)]
    batches = cluster_shots_into_batches(shots)
    assert len(batches) == 1
    only = batches[0]
    assert len(only.shots) == 4
    assert only.shared_location_id == "L1"


def test_location_break_starts_new_batch():
    """A change of location_id starts a new batch.

    The lone trailing L2 shot is expected to be flagged below_threshold.
    """
    shots = [_shot("A", 1, "L1"), _shot("B", 2, "L1"), _shot("C", 3, "L2")]
    batches = cluster_shots_into_batches(shots)
    assert len(batches) == 2
    first, second = batches
    assert first.shared_location_id == "L1"
    assert second.shared_location_id == "L2"
    assert second.below_threshold is True


def test_max_batch_size_caps_at_explicit_6():
    """Validates the >5 packing path when callers opt into fal.ai's full cap.

    Ten contiguous same-location shots with ``max_batch_size=6`` should
    pack as [6, 4].
    """
    shots = [_shot("S%d" % n, n, "L1") for n in range(10)]
    batches = cluster_shots_into_batches(shots, max_batch_size=6)
    assert len(batches) == 2
    assert [len(b.shots) for b in batches] == [6, 4]


def test_max_batch_size_default_caps_at_4():
    """Default cap is 4 per JT narrative-pacing review (2026-05-17)."""
    shots = [_shot(f"S{n}", n, "L1") for n in range(10)]
    batches = cluster_shots_into_batches(shots)
    # Expected packing is [4, 4, 2]. The trailing pair is under
    # min_batch_size=3, so it must carry the below_threshold (orphan) flag.
    assert len(batches) == 3
    assert [len(b.shots) for b in batches] == [4, 4, 2]
    assert batches[-1].below_threshold is True


def test_scene_index_gap_breaks_batch():
    """A jump in scene_index (2 -> 9) splits the run even at one location.

    Checks where the split lands, not only how many batches come back.
    Asserting only ``len(batches) == 2`` would also accept a wrong
    partition such as ``[A] | [B, D]``. The break must sit exactly at
    the gap: ``[A, B] | [D]``.
    """
    shots = [
        _shot("A", 1, "L1"), _shot("B", 2, "L1"),
        _shot("D", 9, "L1"),
    ]
    batches = cluster_shots_into_batches(shots)
    assert len(batches) == 2
    assert [[s.shot_id for s in b.shots] for b in batches] == [["A", "B"], ["D"]]


def test_drastic_angle_jump_breaks_batch():
    """WS followed by ECU in the same scene is expected to split into two batches."""
    wide = _shot("A", 1, "L1", stype="WS")
    extreme_close = _shot("B", 1, "L1", stype="ECU")
    batches = cluster_shots_into_batches([wide, extreme_close])
    assert len(batches) == 2


def test_drastic_angle_jump_checks_adjacent_not_first(monkeypatch):
    """Angle-jump detection must compare shot N-1 -> N, not batch[0] -> N.

    The jump table is patched so that only WS <-> ECU counts as drastic.
    The sequence WS -> MS -> ECU then has no drastic *adjacent* pair and
    must stay in one batch. An implementation that compares each shot
    against the batch's first shot (WS) would wrongly split before the ECU.
    """
    jump_table = {"WS": frozenset({"ECU"}), "ECU": frozenset({"WS"})}
    monkeypatch.setattr(
        "recoil.pipeline._lib.scene_clusterer._DRASTIC_JUMP", jump_table
    )
    shots = [
        _shot(sid, idx, "L1", stype=kind)
        for idx, (sid, kind) in enumerate(
            [("A", "WS"), ("B", "MS"), ("C", "ECU")], start=1
        )
    ]
    batches = cluster_shots_into_batches(shots)
    # The diagnostic message is built only when the assertion fails.
    assert len(batches) == 1, (
        f"Expected 1 batch (no drastic adjacent jumps) but got {len(batches)}: "
        + ", ".join(str([s.shot_id for s in b.shots]) for b in batches)
    )
    assert len(batches[0].shots) == 3


def test_below_threshold_flag_set_correctly():
    """A two-shot batch is expected to carry the below_threshold flag."""
    pair = [_shot("A", 1, "L1"), _shot("B", 1, "L1")]
    batches = cluster_shots_into_batches(pair)
    assert len(batches) == 1
    assert batches[0].below_threshold is True


def test_character_dedup_and_order():
    """shared_characters de-duplicates names across every shot in the batch.

    All three shots must land in a single batch for this to mean anything.
    A batch of just [A, B] would already yield ["JADE", "WREN"], so without
    the precondition the dedup check could pass without ever seeing shot C.

    NOTE(review): ["JADE", "WREN"] is both first-appearance order and
    alphabetical order, so this test does not pin which ordering the
    clusterer uses. Add a case like ["WREN"], ["JADE"] if that matters.
    """
    shots = [
        _shot("A", 1, "L1", chars=["JADE"]),
        _shot("B", 2, "L1", chars=["WREN", "JADE"]),
        _shot("C", 3, "L1", chars=["WREN"]),
    ]
    batches = cluster_shots_into_batches(shots)
    assert len(batches) == 1
    assert len(batches[0].shots) == 3
    assert batches[0].shared_characters == ["JADE", "WREN"]


def test_total_duration_sum():
    """total_duration_s adds up the member durations (3 + 4 + 5 = 12)."""
    durations = (3, 4, 5)
    shots = [
        _shot(sid, idx, "L1", dur=d)
        for idx, (sid, d) in enumerate(zip("ABC", durations), start=1)
    ]
    batches = cluster_shots_into_batches(shots)
    assert batches[0].total_duration_s == 12.0


def test_scene_index_range_populated():
    """Scenes 5..7 in one batch report scene_index_range == (5, 7)."""
    shots = [_shot(sid, scene, "L1") for sid, scene in zip("ABC", range(5, 8))]
    batches = cluster_shots_into_batches(shots)
    assert batches[0].scene_index_range == (5, 7)
