from __future__ import annotations

from copy import deepcopy
from pathlib import Path

from recoil.pipeline._lib.dispatch_payload import PayloadContext
from recoil.pipeline._lib.plan_loader import CanonicalShot, CharacterEntry
from recoil.pipeline._lib.shot_primitive import (
    ShotPrimitive,
    primitive_from_payload_context,
    primitive_from_storyboard_shot,
)


def _storyboard_shot(shot_id: str = "EP001_SH11") -> dict:
    return {
        "shot_id": shot_id,
        "scene_index": 3,
        "source_text": "Jade crosses the med bay as Torch realizes the pod is opening.",
        "routing_data": {
            "target_editorial_duration_s": 6,
            "has_dialogue": True,
        },
        "prompt_data": {
            "shot_type": "OTS",
            "prompt_skeleton": {
                "subject_line": "Jade foreground shoulder, Torch at the pod hatch",
                "environment_line": "cramped Tartarus med bay",
                "action_line": "Torch braces one palm on the pod glass as vapor spills low.",
                "emotion_line": "Jade hides panic behind a clipped, controlled breath.",
            },
        },
        "spatial_data": {
            "camera_side": "B",
            "screen_direction": "left-to-right",
        },
        "asset_data": {
            "characters": [
                {"char_id": "JADE", "wardrobe_phase_id": "p1"},
                {"char_id": "TORCH", "wardrobe_phase_id": "p1"},
            ],
            "location_id": "tartarus_med_bay",
        },
        "provenance": {"episode": "ep_001", "pass": "storyboard"},
    }


def _canonical(raw: dict, *, duration_s: float | None = None) -> CanonicalShot:
    """Wrap a raw storyboard-shot dict in a ``CanonicalShot`` for payload tests.

    When ``duration_s`` is omitted, it is taken from
    ``routing_data.target_editorial_duration_s`` and coerced to ``float``.
    Pipeline, video model and aspect ratio are fixed test values; the raw
    dict itself is attached as ``raw``.
    """
    asset_data = raw.get("asset_data") or {}
    routing = raw["routing_data"]
    resolved_duration = (
        float(routing["target_editorial_duration_s"]) if duration_s is None else duration_s
    )
    cast = [
        CharacterEntry(char_id=entry["char_id"], wardrobe_phase_id=entry.get("wardrobe_phase_id"))
        for entry in asset_data.get("characters", [])
    ]
    return CanonicalShot(
        shot_id=raw["shot_id"],
        scene_index=raw["scene_index"],
        sequence_id=None,
        pipeline="video",
        previs_model=None,
        video_model="seeddance-2.0",
        location_id=asset_data.get("location_id"),
        characters=cast,
        shot_type=raw["prompt_data"]["shot_type"],
        duration_s=resolved_duration,
        is_env_only=False,
        has_dialogue=bool(routing.get("has_dialogue")),
        aspect_ratio="9:16",
        raw=raw,
    )


def test_storyboard_shot_normalizes_without_mutating_source() -> None:
    """A storyboard shot maps field-by-field onto ``ShotPrimitive``.

    The shot's own location, camera side and screen direction win over the
    conflicting ``scene_defaults``. The intent joins source text, action line
    and emotion line with `` | ``. Neither input is mutated: the shot dict and
    the ``scene_defaults`` mapping must both compare equal to their
    pre-call snapshots.
    """
    shot = _storyboard_shot("EP001_SH11")
    scene_defaults = {
        "location_id": "scene_default_location",
        "camera_side": "A",
        "screen_direction": "center",
    }
    before = deepcopy(shot)
    # scene_defaults is an input too; snapshot it so the "without mutating
    # source" contract covers both arguments, not just the shot dict.
    defaults_before = deepcopy(scene_defaults)

    primitive = primitive_from_storyboard_shot(shot, scene_defaults=scene_defaults)

    assert shot == before
    assert scene_defaults == defaults_before
    assert primitive == ShotPrimitive(
        shot_id="EP001_SH11",
        scene_index=3,
        shot_type="OTS",
        target_editorial_duration_s=6.0,
        intent=(
            "Jade crosses the med bay as Torch realizes the pod is opening. | "
            "Torch braces one palm on the pod glass as vapor spills low. | "
            "Jade hides panic behind a clipped, controlled breath."
        ),
        camera_side="B",
        screen_direction="left-to-right",
        has_dialogue=True,
        char_ids=["JADE", "TORCH"],
        location_id="tartarus_med_bay",
        timing_segments=[],
        strategy=None,
        refs={"provenance": {"episode": "ep_001", "pass": "storyboard"}},
    )


def test_storyboard_shot_inherits_shallow_scene_defaults_and_never_empty_intent() -> None:
    """Missing shot fields fall back to scene defaults; blank text yields a synthetic intent.

    ``location_id`` and ``screen_direction`` are removed from the shot and must
    come from ``scene_defaults``. Everything else still comes from the shot.
    With the text inputs blanked, the intent must be the
    ``"<shot_id> <shot_type> shot intent"`` placeholder rather than empty.
    """
    shot = _storyboard_shot("EP002_SH09")
    skeleton = shot["prompt_data"]["prompt_skeleton"]
    # Drop the fields the scene defaults are expected to fill in.
    del shot["asset_data"]["location_id"]
    del shot["spatial_data"]["screen_direction"]
    # Blank the three text fields the intent is normally joined from.
    shot["source_text"] = ""
    skeleton["action_line"] = ""
    skeleton["emotion_line"] = ""

    defaults = {"location_id": "tartarus_airlock", "screen_direction": "right-to-left"}
    primitive = primitive_from_storyboard_shot(shot, scene_defaults=defaults)

    expected = {
        "scene_index": 3,
        "shot_type": "OTS",
        "camera_side": "B",
        "screen_direction": "right-to-left",
        "target_editorial_duration_s": 6.0,
        "char_ids": ["JADE", "TORCH"],
        "location_id": "tartarus_airlock",
    }
    for attr, want in expected.items():
        assert getattr(primitive, attr) == want, attr
    # Identity check on purpose: a truthy non-bool must not pass.
    assert primitive.has_dialogue is True
    assert primitive.intent == "EP002_SH09 OTS shot intent"


def test_start_end_storyboard_shot_maps_refs() -> None:
    """Top-level ``start_frame`` / ``end_frame`` paths on a shot are copied into ``refs``."""
    frames = {"start_frame": "/tmp/start.png", "end_frame": "/tmp/end.png"}
    shot = _storyboard_shot("EP001_SH12")
    shot.update(frames)

    refs = primitive_from_storyboard_shot(shot, scene_defaults={}).refs

    for key, path in frames.items():
        assert refs[key] == path


def test_payload_context_maps_r2v_batch() -> None:
    """An ``r2v_multi`` batch context becomes one primitive with per-shot timing segments.

    The primitive takes the context's pass-level ``shot_id``, concatenates
    character ids across the batch in order, reports a 10 s total duration,
    stores the ref manifest under ``refs["manifest"]``, and emits one timing
    segment per batch shot from ``segment_timestamps``.
    """
    lead = _storyboard_shot("EP001_SH11")
    follow = _storyboard_shot("EP001_SH12")
    follow["routing_data"]["target_editorial_duration_s"] = 4
    follow["asset_data"]["characters"] = [{"char_id": "MARA", "wardrobe_phase_id": "p1"}]
    batch = [_canonical(lead), _canonical(follow)]

    primitive = primitive_from_payload_context(
        PayloadContext(
            project="tartarus",
            modality="r2v_multi",
            shot_id="EP001_PASS_001",
            shot=batch[0],
            batch_shots=batch,
        ),
        ref_manifest={"identity_1": 1, "identity_2": 2, "scene_1": 3},
        segment_timestamps=[(0.0, 6.0), (6.0, 10.0)],
    )

    assert primitive.shot_id == "EP001_PASS_001"
    assert primitive.char_ids == ["JADE", "TORCH", "MARA"]
    assert primitive.target_editorial_duration_s == 10.0
    assert primitive.refs["manifest"] == {"identity_1": 1, "identity_2": 2, "scene_1": 3}

    segments = primitive.timing_segments
    # Intent text is generated, so pin every other key exactly and only
    # require that each segment carries a non-empty intent.
    assert [{k: v for k, v in seg.items() if k != "intent"} for seg in segments] == [
        {"shot_id": "EP001_SH11", "start_s": 0.0, "end_s": 6.0, "duration_s": 6.0},
        {"shot_id": "EP001_SH12", "start_s": 6.0, "end_s": 10.0, "duration_s": 4.0},
    ]
    assert all(seg["intent"] for seg in segments)


def test_payload_context_maps_i2v_resolved_frames() -> None:
    """A ``video_i2v`` context records the manifest and resolved frames as plain strings.

    The frames are passed in as ``Path`` objects but must appear in ``refs``
    as ``str`` values. A single-shot context produces no timing segments.
    """
    shot = _canonical(_storyboard_shot("EP001_SH13"))
    context = PayloadContext(
        project="tartarus",
        modality="video_i2v",
        shot_id=shot.shot_id,
        shot=shot,
    )

    primitive = primitive_from_payload_context(
        context,
        ref_manifest={"identity_1": 1},
        start_frame=Path("/tmp/start.png"),
        end_frame=Path("/tmp/end.png"),
    )

    refs = primitive.refs
    assert primitive.shot_id == "EP001_SH13"
    assert refs["manifest"] == {"identity_1": 1}
    # A Path never compares equal to a str, so this also pins the conversion.
    assert (refs["start_frame"], refs["end_frame"]) == ("/tmp/start.png", "/tmp/end.png")
    assert primitive.timing_segments == []
