from __future__ import annotations

from pathlib import Path

import pytest

from recoil.pipeline._lib import dispatch_payload as dp
from recoil.pipeline._lib.plan_loader import CanonicalShot, CharacterEntry


# Minimal stand-in for a PNG file: the 8-byte PNG signature followed by 32
# zero bytes of padding (not real PNG chunks). Used to give frame files
# plausible PNG content on disk.
PNG_BYTES = b"\x89PNG\r\n\x1a\n" + bytes(32)


def _i2v_shot(start_frame: str, end_frame: str) -> CanonicalShot:
    """Build the env-only Kling image-to-video shot ``EP001_SH11``.

    ``start_frame`` and ``end_frame`` are stored verbatim under the same keys
    in the shot's ``raw`` plan record; the tests pass file paths as strings.
    The shot has no characters and is located in ``pod_bay``.
    """
    shot_id = "EP001_SH11"
    location_id = "pod_bay"
    shot_type = "CU"
    duration_s = 5.0

    raw = dict(
        shot_id=shot_id,
        scene_index=1,
        duration_s=duration_s,
        start_frame=start_frame,
        end_frame=end_frame,
        source_text="A pod opens from the first frame into the final frame.",
        asset_data={"characters": [], "location_id": location_id},
        prompt_data={
            "shot_type": shot_type,
            "action_line": "The pod opens.",
            "emotion_line": "Contained pressure releases.",
        },
    )

    return CanonicalShot(
        shot_id=shot_id,
        scene_index=1,
        sequence_id=None,
        pipeline="video",
        previs_model="gemini-3-pro-image-preview",
        video_model="kling-v3",
        location_id=location_id,
        characters=[],
        shot_type=shot_type,
        duration_s=duration_s,
        is_env_only=True,
        has_dialogue=False,
        aspect_ratio="9:16",
        raw=raw,
    )


def _r2v_shot(shot_id: str, char_id: str) -> CanonicalShot:
    """Build a single-character Seedance reference-to-video shot in ``med_bay``.

    ``shot_id`` is used both for the shot and its raw plan record. ``char_id``
    is the character's plan id; its title-cased form appears in the
    ``source_text`` and ``action_line`` narrative fields.
    """
    location_id = "med_bay"
    shot_type = "OTS"
    duration_s = 3.0
    # The same narrative sentence feeds both the source text and the action line.
    crossing = f"{char_id.title()} crosses the med bay."

    prompt_data = {
        "shot_type": shot_type,
        "action_line": crossing,
        "emotion_line": "Breath held tight.",
    }
    raw = {
        "shot_id": shot_id,
        "scene_index": 2,
        "duration_s": duration_s,
        "shot_type": shot_type,
        "camera_side": "B",
        "screen_direction": "left-to-right",
        "source_text": crossing,
        "asset_data": {"characters": [char_id], "location_id": location_id},
        "prompt_data": prompt_data,
    }

    return CanonicalShot(
        shot_id=shot_id,
        scene_index=2,
        sequence_id=None,
        pipeline="video",
        previs_model="gemini-3-pro-image-preview",
        video_model="seeddance-2.0",
        location_id=location_id,
        characters=[CharacterEntry(char_id=char_id)],
        shot_type=shot_type,
        duration_s=duration_s,
        is_env_only=False,
        has_dialogue=False,
        aspect_ratio="9:16",
        raw=raw,
    )


@pytest.fixture(autouse=True)
def _patch_common(monkeypatch: pytest.MonkeyPatch):
    """Isolate every test from project config and reference-image discovery.

    Clears ``dp._project_config_cache`` before and after each test, makes
    ``dp.load_project_config`` return an empty dict, and makes
    ``dp._collect_reference_images`` return ``([], {})``.
    """

    def _empty_project_config(_project):
        return {}

    def _no_reference_images(*_args, **_kwargs):
        return [], {}

    dp._project_config_cache.clear()
    monkeypatch.setattr(dp, "load_project_config", _empty_project_config)
    monkeypatch.setattr(dp, "_collect_reference_images", _no_reference_images)

    yield

    # Drop anything cached during the test so later tests start clean.
    dp._project_config_cache.clear()


def test_video_i2v_dry_run_serializes_bound_start_and_tail_not_image(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A dry-run ``video_i2v`` dispatch exposes the bound frames, never ``image``.

    The payload must carry the start frame path and a non-empty
    ``image_tail``, and must not contain an ``image`` key.
    """
    start_png = tmp_path / "start.png"
    end_png = tmp_path / "end.png"
    for frame in (start_png, end_png):
        frame.write_bytes(PNG_BYTES)

    def _canned_author_pass(*_args, **_kwargs):
        return "The pod opens from the first frame into the final frame."

    monkeypatch.setattr(dp, "author_pass", _canned_author_pass)

    shot = _i2v_shot(str(start_png), str(end_png))
    payload = dp.build_dispatch_payload(
        shot=shot,
        project="tartarus",
        modality="video_i2v",
        model_override="kling-v3",
        dry_run=True,
    )

    assert payload["start_frame"] == str(start_png)
    assert payload["image_tail"]
    assert "image" not in payload


def test_r2v_multi_serializes_ref_stack_and_never_image_tail(
    tmp_path: Path,
) -> None:
    """``r2v_multi`` payloads list reference images in order and omit ``image_tail``.

    An end frame is supplied on purpose, to check that it does not leak into
    the payload as ``image_tail``.

    Reference paths live under ``tmp_path`` instead of a hard-coded ``/tmp``.
    This keeps the test from colliding with stray files on the host, and
    deriving the expected strings from the same ``Path`` objects keeps the
    comparison independent of platform path formatting.
    """
    end = tmp_path / "end.png"
    end.write_bytes(PNG_BYTES)
    # Not written to disk; only the serialized path strings are asserted.
    jade_ref = tmp_path / "jade.png"
    wren_ref = tmp_path / "wren.png"
    jade = _r2v_shot("EP002_SH01", "JADE")
    wren = _r2v_shot("EP002_SH02", "WREN")

    payload = dp.build_unified_payload(
        dp.PayloadContext(
            project="tartarus",
            modality="r2v_multi",
            shot_id="EP002_PASS_009",
            prompt="@Image1 crosses. @Image2 follows.",
            end_frame_path=end,
            reference_image_paths=[jade_ref, wren_ref],
            shot=jade,
            batch_shots=[jade, wren],
            model_id="seeddance-2.0",
        )
    )

    assert payload["reference_images"] == [str(jade_ref), str(wren_ref)]
    assert "image_tail" not in payload


def test_i2v_bound_missing_start_frame_blocks(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A start/end-frame author result lacking ``start_frame`` must be rejected.

    The author-aware prompt builder is stubbed to return a
    ``start_end_frame`` strategy whose ``payload_refs`` contain only
    ``image_tail``. ``build_unified_payload`` is expected to raise
    ``DispatchPayloadError`` that names ``payload_refs.start_frame``.
    """
    start = tmp_path / "start.png"
    end = tmp_path / "end.png"
    start.write_bytes(PNG_BYTES)
    end.write_bytes(PNG_BYTES)

    monkeypatch.setattr(
        dp,
        "_build_author_aware_prompt",
        lambda *a, **k: dp.AuthorPromptResult(
            prompt="Bound brief.",
            modality="video_i2v",
            strategy="start_end_frame",
            # start_frame deliberately omitted: that is the condition under test.
            payload_refs={"image_tail": str(end)},
        ),
    )

    # ``match`` is applied with re.search, so escape the dot to require the
    # literal field name rather than "payload_refs<any char>start_frame".
    with pytest.raises(dp.DispatchPayloadError, match=r"payload_refs\.start_frame"):
        dp.build_unified_payload(
            dp.PayloadContext(
                project="tartarus",
                modality="video_i2v",
                shot_id="EP001_SH11",
                start_frame_path=start,
                shot=_i2v_shot(str(start), str(end)),
                model_id="kling-v3",
            )
        )
