from __future__ import annotations

from recoil.pipeline._lib.dispatch_payload import PayloadContext
from recoil.pipeline._lib.plan_loader import CanonicalShot
from recoil.pipeline._lib.shot_primitive import primitive_from_payload_context


def _shot(shot_id: str, *, spatial_data: dict) -> CanonicalShot:
    """Build a minimal video-pipeline ``CanonicalShot`` fixture.

    Every shot is placed in scene 1 as a 3-second ``MS`` shot for
    ``seeddance-2.0`` at 9:16, with no characters, location, sequence,
    or dialogue. Only ``shot_id`` and ``spatial_data`` vary between calls.
    ``spatial_data`` is stored as-is (not copied) under the
    ``"spatial_data"`` key of the shot's ``raw`` payload.
    """
    raw_payload = dict(
        shot_id=shot_id,
        scene_index=1,
        prompt_data={"shot_type": "MS"},
        routing_data={"target_editorial_duration_s": 3},
        spatial_data=spatial_data,
    )
    # Typed fields that are identical for every fixture shot; rebuilt per
    # call so each shot gets its own fresh ``characters`` list.
    fixed_fields = dict(
        scene_index=1,
        sequence_id=None,
        pipeline="video",
        previs_model=None,
        video_model="seeddance-2.0",
        location_id=None,
        characters=[],
        shot_type="MS",
        duration_s=3.0,
        is_env_only=False,
        has_dialogue=False,
        aspect_ratio="9:16",
    )
    return CanonicalShot(shot_id=shot_id, raw=raw_payload, **fixed_fields)


def _ctx(batch: list[CanonicalShot]) -> PayloadContext:
    """Wrap *batch* in a ``tartarus`` / ``r2v_multi`` ``PayloadContext``.

    The first shot in *batch* becomes the context's primary ``shot``, and
    the whole batch is passed through as ``batch_shots``. An empty batch
    raises ``IndexError``.
    """
    lead_shot = batch[0]
    return PayloadContext(
        project="tartarus",
        modality="r2v_multi",
        shot_id="EP001_PASS_001",
        shot=lead_shot,
        batch_shots=batch,
    )


def test_segment_sublocation_passthrough() -> None:
    """Each timing segment carries its source shot's sublocation and setting."""
    # (shot_id, sublocation, setting) for each shot in the pass.
    shot_specs = (
        ("EP001_SH11", "pod_platform", "pod platform, beside the open cryo-pod"),
        ("EP001_SH12", "anchor_cables", "anchor cables, the drop yawning below"),
    )
    batch = [
        _shot(sid, spatial_data={"sublocation": subloc, "setting": setting})
        for sid, subloc, setting in shot_specs
    ]

    segments = primitive_from_payload_context(
        _ctx(batch), ref_manifest=None
    ).timing_segments

    # Length check first so the lockstep zip below cannot silently truncate.
    assert len(segments) == 2
    for seg, shot in zip(segments, batch):
        expected = shot.raw["spatial_data"]
        assert seg["sublocation"] == expected["sublocation"]
        assert seg["setting"] == expected["setting"]


def test_segments_without_sublocation_unchanged() -> None:
    """Shots lacking sublocation/setting produce segments without those keys."""
    batch = [
        _shot("EP001_SH11", spatial_data={"camera_side": "A"}),
        _shot("EP001_SH12", spatial_data={}),
    ]

    segments = primitive_from_payload_context(
        _ctx(batch), ref_manifest=None
    ).timing_segments

    assert len(segments) == 2
    absent_keys = ("sublocation", "setting")
    for seg in segments:
        assert not any(key in seg for key in absent_keys)
