from __future__ import annotations

from pathlib import Path
from typing import Any

import pytest

from recoil.pipeline._lib.grouping import (
    ONER_PROMPT_DIRECTIVE,
    GroupingContext,
    get_grouping,
)
from recoil.pipeline._lib.plan_loader import CanonicalPlan, CanonicalShot, CharacterEntry
from recoil.pipeline.orchestrator.coverage_planner import CoveragePass, CoverageSegment


def _shot(
    shot_id: str,
    scene_index: int = 1,
    location_id: str = "LOC_A",
    shot_type: str = "MS",
    duration_s: float = 4.0,
) -> CanonicalShot:
    """Return a minimal still-pipeline ``CanonicalShot`` featuring JADE.

    The typed fields and the ``raw`` payload carry the same scene index,
    location, shot type and duration, so either view of the shot agrees.
    Camera side is always "A" and the aspect ratio is always "9:16".
    """
    raw_payload = {
        "shot_id": shot_id,
        "scene_index": scene_index,
        "routing_data": {"target_editorial_duration_s": duration_s},
        "asset_data": {
            "location_id": location_id,
            "characters": [{"char_id": "JADE"}],
        },
        "prompt_data": {"shot_type": shot_type},
        "spatial_data": {"camera_side": "A"},
    }
    cast = [CharacterEntry(char_id="JADE")]
    return CanonicalShot(
        shot_id=shot_id,
        scene_index=scene_index,
        sequence_id=None,
        pipeline="still",
        previs_model=None,
        video_model=None,
        location_id=location_id,
        characters=cast,
        shot_type=shot_type,
        duration_s=duration_s,
        is_env_only=False,
        has_dialogue=False,
        aspect_ratio="9:16",
        raw=raw_payload,
    )


def _plan(shots: list[CanonicalShot]) -> CanonicalPlan:
    """Wrap *shots* in a ``CanonicalPlan`` for episode EP001 of project "demo"."""
    plan_file = Path("ep_001_plan.json")
    return CanonicalPlan(
        episode_id="EP001",
        project="demo",
        shots=shots,
        source_path=plan_file,
    )


def _ctx(
    shots: list[CanonicalShot],
    selected_coverage_passes: list[CoveragePass] | None = None,
) -> GroupingContext:
    """Build a ``GroupingContext`` for project "demo", episode 1.

    Every shot is assigned tier 2 and no wildcard override is set. When
    *selected_coverage_passes* is given, it is copied into a fresh list.
    Otherwise the context gets an empty list of selected passes.
    """
    passes = [] if selected_coverage_passes is None else list(selected_coverage_passes)
    tiers = dict.fromkeys((shot.shot_id for shot in shots), 2)
    return GroupingContext(
        project="demo",
        episode=1,
        canonical_plan=_plan(shots),
        selected_coverage_passes=passes,
        tier_map=tiers,
        wildcard_override=None,
    )


def _coverage_pass(pass_id: str = "EP001_PASS_017_SH23_24_A_JADE") -> CoveragePass:
    """Return a format-C, JADE-focused coverage pass over EP001_SH23..SH24.

    The pass holds three 4-second segments: one per source shot, plus a
    trailing wildcard segment. Callers can override *pass_id*. Identity tests
    read the ordinal from the ``PASS_NNN`` part of that id.
    """
    # (source_shot_id, shot_type, prompt, extra kwargs) per segment.
    # ``is_wildcard`` is passed only for the wildcard segment, exactly as before.
    segment_specs: list[tuple[str, str, str, dict[str, bool]]] = [
        ("EP001_SH23", "MS", "Jade waits.", {}),
        ("EP001_SH24", "CU", "Jade looks up.", {}),
        ("wildcard", "MS", "Jade reacts.", {"is_wildcard": True}),
    ]
    segments = [
        CoverageSegment(
            segment_index=position,
            source_shot_id=source_shot_id,
            shot_type=shot_type,
            duration_s=4,
            prompt=prompt,
            **extra,
        )
        for position, (source_shot_id, shot_type, prompt, extra) in enumerate(segment_specs)
    ]
    return CoveragePass(
        pass_id=pass_id,
        episode_id="EP001",
        shot_range=("EP001_SH23", "EP001_SH24"),
        camera_side="A",
        label="JADE B (LOC_A)",
        focus_character="JADE",
        pass_type="character",
        location_id="LOC_A",
        segments=segments,
        element_config={"character_elements": [{"char_id": "JADE"}], "location_id": "LOC_A"},
        generation_config={"model": "seeddance-2.0", "mode": "i2v", "cfg_scale": 0.55},
        format_type="C",
        wildcard_enabled=True,
    )


def test_registry_returns_grouping_and_clear_unknown_error():
    """Known names resolve to a strategy; unknown names raise a ValueError naming them."""
    coverage_strategy = get_grouping("coverage")
    assert coverage_strategy.name == "coverage"

    expected_message = "Unknown grouping strategy 'missing'"
    with pytest.raises(ValueError, match=expected_message):
        get_grouping("missing")


def test_continuity_identity_ordinals_use_public_registry():
    """Eight single-shot scenes collapse into two four-shot continuity groups.

    The groups are numbered 1 and 2 and use the ``r2v_multi`` modality.
    """
    expected_ids = [f"EP001_SH{n:02d}" for n in range(1, 9)]
    shots = [_shot(shot_id, scene) for scene, shot_id in enumerate(expected_ids, start=1)]

    groups = get_grouping("continuity").assemble(shots, _ctx(shots))

    assert [g.identity.strategy for g in groups] == ["continuity"] * 2
    assert [g.identity.ordinal for g in groups] == [1, 2]
    assert [g.identity.shot_ids for g in groups] == [expected_ids[:4], expected_ids[4:]]
    assert [g.modality for g in groups] == ["r2v_multi"] * 2


def test_solo_identity_ordinal_is_zero():
    """Solo grouping yields one ``video_i2v`` group per shot, each with ordinal 0."""
    ids = ["EP001_SH01", "EP001_SH02"]
    shots = [_shot(shot_id) for shot_id in ids]

    groups = get_grouping("solo").assemble(shots, _ctx(shots))

    assert [g.identity.strategy for g in groups] == ["solo"] * 2
    assert [g.identity.ordinal for g in groups] == [0] * 2
    assert [g.modality for g in groups] == ["video_i2v"] * 2
    assert [g.identity.shot_ids for g in groups] == [[shot_id] for shot_id in ids]


def test_coverage_identity_and_config_preservation_for_selected_passes():
    """A pre-selected coverage pass becomes exactly one group.

    The group keeps the pass's identity, the pass object itself, and its
    generation and element configs.
    """
    shots = [_shot("EP001_SH23", 1), _shot("EP001_SH24", 2)]
    chosen_pass = _coverage_pass()
    chosen_pass.source_pass_id = "loaded-source-pass"
    context = _ctx(shots, [chosen_pass])

    groups = get_grouping("coverage").assemble(shots, context)

    assert len(groups) == 1
    only_group = groups[0]
    identity = only_group.identity
    assert identity.strategy == "coverage"
    assert identity.ordinal == 17  # matches "PASS_017" in the default pass id
    assert identity.shot_ids == ["EP001_SH23", "EP001_SH24"]
    assert identity.source_pass_id == "loaded-source-pass"

    assert only_group.scene_id == chosen_pass.pass_id
    assert only_group.coverage_pass is chosen_pass
    assert only_group.generation_config == chosen_pass.generation_config
    assert only_group.element_config == chosen_pass.element_config

    kept_pass = only_group.coverage_pass
    assert kept_pass.format_type == "C"
    assert kept_pass.wildcard_enabled is True


def test_coverage_generated_passes_call_planner_and_preserve_identity(monkeypatch):
    """With no pre-selected passes, coverage grouping builds passes via the planner.

    ``_build_coverage_passes`` is monkeypatched with a fake. The fake records
    the arguments it receives and returns one canned pass. That lets the test
    check what the grouping forwards to the planner and how the returned pass
    becomes a group.
    """
    shots = [_shot("EP001_SH23", 1), _shot("EP001_SH24", 2)]
    generated = _coverage_pass("EP001_PASS_003_SH23_24_A_JADE")
    # ``Any`` rather than ``object``: the assertions below subscript the
    # recorded ``shots_dicts`` value, which type checkers reject on ``object``.
    calls: dict[str, Any] = {}

    def fake_build_passes(shots_dicts, project, episode, tier_map, wildcard_override):
        calls["shots_dicts"] = shots_dicts
        calls["project"] = project
        calls["episode"] = episode
        calls["tier_map"] = tier_map
        calls["wildcard_override"] = wildcard_override
        return [generated]

    monkeypatch.setattr(
        "recoil.pipeline._lib.grouping._build_coverage_passes",
        fake_build_passes,
    )

    groups = get_grouping("coverage").assemble(shots, _ctx(shots))

    # Context values reach the planner unchanged.
    assert calls["project"] == "demo"
    assert calls["episode"] == 1
    assert calls["tier_map"] == {"EP001_SH23": 2, "EP001_SH24": 2}
    assert calls["wildcard_override"] is None
    # Shots arrive as dicts shaped like the ``raw`` payload built by ``_shot``.
    assert calls["shots_dicts"][0]["shot_id"] == "EP001_SH23"
    assert calls["shots_dicts"][0]["asset_data"]["location_id"] == "LOC_A"

    # The planner returned one pass, so there must be exactly one group.
    # Without this check, indexing ``groups[0]`` would hide extra groups.
    assert len(groups) == 1
    group = groups[0]
    assert group.identity.strategy == "coverage"
    assert group.identity.ordinal == 3  # matches "PASS_003" in the generated pass id
    assert group.identity.shot_ids == ["EP001_SH23", "EP001_SH24"]
    assert group.identity.source_pass_id == generated.pass_id
    assert group.generation_config == generated.generation_config
    assert group.element_config == generated.element_config


def test_oner_groups_by_scene_and_emits_format_c_prompt_directive():
    """Oner grouping splits on scene index, not on location or shot type.

    Every group carries a prompt directive. The first group's directive is
    the Format C "Oner" text.
    """
    shot_specs = [
        ("EP001_SH01", 1, "LOC_A", "WS"),
        ("EP001_SH02", 1, "LOC_B", "ECU"),
        ("EP001_SH03", 2, "LOC_A", "MS"),
    ]
    shots = [
        _shot(shot_id, scene_index=scene, location_id=location, shot_type=framing)
        for shot_id, scene, location, framing in shot_specs
    ]

    groups = get_grouping("oner").assemble(shots, _ctx(shots))

    assert [g.scene_id for g in groups] == ["ONER_001", "ONER_002"]
    assert [g.identity.strategy for g in groups] == ["oner"] * 2
    assert [g.identity.ordinal for g in groups] == [1, 2]
    assert [g.identity.shot_ids for g in groups] == [
        ["EP001_SH01", "EP001_SH02"],
        ["EP001_SH03"],
    ]
    assert all(g.prompt_directive for g in groups)
    assert groups[0].prompt_directive == ONER_PROMPT_DIRECTIVE

    expected_directive_lines = [
        "### Format C — Oner (Continuous Take)",
        "**What:** Generate one continuous 10-15 second shot. Extract edit points in post.",
        "**Like:** Shooting a long take on a Steadicam. Find the cuts later.",
        "**Use for:** Atmospheric moments, set pieces, climactic sequences where performance continuity matters.",
        "**Output:** One continuous clip. Editor finds the gold inside it.",
    ]
    assert ONER_PROMPT_DIRECTIVE.splitlines() == expected_directive_lines


def test_oner_splits_scene_only_when_duration_cap_forces_distinct_ordinals():
    """One scene of five 4-second shots is split into groups of 3 and 2.

    Each shot has a distinct location, so the split must come from duration.
    Presumably the oner cap sits between 12 s and 16 s; confirm against the
    grouping implementation.
    """
    shot_ids = [f"EP001_SH{i:02d}" for i in range(1, 6)]
    shots = [
        _shot(shot_id, scene_index=1, location_id=f"LOC_{i}", duration_s=4.0)
        for i, shot_id in enumerate(shot_ids, start=1)
    ]

    groups = get_grouping("oner").assemble(shots, _ctx(shots))

    assert [g.identity.ordinal for g in groups] == [1, 2]
    assert [g.scene_id for g in groups] == ["ONER_001", "ONER_002"]
    assert [g.identity.shot_ids for g in groups] == [shot_ids[:3], shot_ids[3:]]
    assert [g.identity.strategy for g in groups] == ["oner"] * 2
