from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path

import pytest

from recoil.pipeline._lib.author_strategies import (
    AUTHOR_STRATEGIES,
    DEFAULT_AUTHOR_STRATEGY,
    AuthorStrategy,
    StrategyResolutionError,
    resolve_strategy,
)
from recoil.pipeline._lib.shot_primitive import ShotPrimitive


# Repository root: the fourth ancestor of this file's resolved path. This
# assumes the test lives at ``<root>/<a>/<b>/<c>/<this file>`` — adjust the
# index if the file moves.
_REPO_ROOT = Path(__file__).resolve().parents[3]

# Environment variables that alter what ``resolve_strategy`` returns. Tests in
# this module set them explicitly when they need them; otherwise they must be
# absent.
_AUTHOR_STRATEGY_ENV_VARS = ("PROSE_AUTHOR_FALLBACK", "RECOIL_AUTHOR_STRATEGY")


@pytest.fixture(autouse=True)
def _isolate_author_strategy_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Unset strategy-selecting env vars so the caller's shell cannot leak into tests.

    Tests that need one of these variables set it themselves with
    ``monkeypatch.setenv`` (the same function-scoped ``monkeypatch`` instance),
    and the original environment is restored on teardown.
    """
    for var_name in _AUTHOR_STRATEGY_ENV_VARS:
        monkeypatch.delenv(var_name, raising=False)


def _primitive(**overrides) -> ShotPrimitive:
    """Build a ``ShotPrimitive`` for a single-character, dialogue-free OTS shot.

    Every keyword in ``overrides`` replaces the matching baseline field (or is
    added alongside them) before the primitive is constructed. Mutable fields
    are fresh objects on each call, so tests cannot leak state into each other.
    """
    baseline = dict(
        shot_id="EP001_SH11",
        scene_index=1,
        shot_type="OTS",
        target_editorial_duration_s=6.0,
        intent="Jade crosses the med bay as Torch braces at the pod.",
        camera_side="B",
        screen_direction="left-to-right",
        has_dialogue=False,
        char_ids=["JADE"],
        location_id="tartarus_med_bay",
        timing_segments=[],
        strategy=None,
        refs={},
    )
    return ShotPrimitive(**{**baseline, **overrides})


def test_start_and_end_refs_resolve_to_video_i2v_start_end_frame() -> None:
    """A shot with both start and end frame refs resolves to video_i2v / start_end_frame on kling-v3."""
    frame_refs = {"start_frame": "/tmp/start.png", "end_frame": "/tmp/end.png"}
    framed_shot = _primitive(refs=frame_refs)

    resolved_modality, resolved = resolve_strategy(framed_shot, model_id="kling-v3")

    assert resolved_modality == "video_i2v"
    assert resolved.name == "start_end_frame"


def test_multi_char_shape_resolves_to_r2v_multi_directed_prose() -> None:
    """Two characters plus an identity manifest resolve to r2v_multi / directed_prose on seeddance-2.0."""
    identity_manifest = {"identity_1": 1, "identity_2": 2}
    two_char_shot = _primitive(
        char_ids=["JADE", "TORCH"],
        refs={"manifest": identity_manifest},
    )

    got_modality, got_strategy = resolve_strategy(two_char_shot, model_id="seeddance-2.0")

    assert got_modality == "r2v_multi"
    assert got_strategy.name == "directed_prose"


def test_incompatible_requested_modality_blocks_loudly() -> None:
    """Requesting r2v_multi for a start/end-frame shot raises ``StrategyResolutionError``."""
    framed_shot = _primitive(refs={"start_frame": "/tmp/start.png", "end_frame": "/tmp/end.png"})
    expected_message = "incompatible|drop the end frame"

    with pytest.raises(StrategyResolutionError, match=expected_message):
        resolve_strategy(framed_shot, model_id="seeddance-2.0", requested_modality="r2v_multi")


def test_live_model_roles_resolve_to_locked_defaults() -> None:
    """Models pinned in ``config/model_roles.json`` resolve to their locked default strategies.

    The model ids are asserted *before* resolution, so that a config change
    surfaces as a clear assertion on the config value rather than as a
    downstream resolution error or an unexplained strategy mismatch.
    """
    roles_path = _REPO_ROOT / "config" / "model_roles.json"
    video_roles = json.loads(roles_path.read_text(encoding="utf-8"))["video"]

    r2v_model = video_roles["multi_shot"]
    i2v_model = video_roles["i2v"]
    assert r2v_model == "seeddance-2.0", f"config video.multi_shot is {r2v_model!r}"
    assert i2v_model == "kling-v3", f"config video.i2v is {i2v_model!r}"

    r2v_modality, r2v_strategy = resolve_strategy(
        _primitive(char_ids=["JADE"]),
        model_id=r2v_model,
        requested_modality="r2v_multi",
    )
    i2v_modality, i2v_strategy = resolve_strategy(
        _primitive(char_ids=["JADE"]),
        model_id=i2v_model,
        requested_modality="video_i2v",
    )

    assert r2v_modality == "r2v_multi"
    assert r2v_strategy.name == "directed_prose"
    assert i2v_modality == "video_i2v"
    assert i2v_strategy.name == "start_end_frame"


def test_exactly_one_default_per_registered_model_modality() -> None:
    """Each registered (model, modality) pair has exactly one ``is_default`` strategy.

    Also checks that ``DEFAULT_AUTHOR_STRATEGY`` names the same strategy the
    registry flags as default for each of its pairs, so the two sources of
    "default" cannot drift apart silently.
    """
    defaults_by_pair: dict[tuple[str, str], list[str]] = defaultdict(list)
    registered_pairs: set[tuple[str, str]] = set()
    for (model_id, modality, _name), strategy in AUTHOR_STRATEGIES.items():
        registered_pairs.add((model_id, modality))
        if strategy.is_default:
            defaults_by_pair[(model_id, modality)].append(strategy.name)

    assert DEFAULT_AUTHOR_STRATEGY == {
        ("seeddance-2.0", "r2v_multi"): "directed_prose",
        ("kling-v3", "video_i2v"): "start_end_frame",
    }
    for pair in registered_pairs:
        flagged = defaults_by_pair[pair]
        assert len(flagged) == 1, f"{pair} has default strategies {flagged!r}; expected exactly one"
    for pair, default_name in DEFAULT_AUTHOR_STRATEGY.items():
        assert defaults_by_pair[pair] == [default_name], (
            f"DEFAULT_AUTHOR_STRATEGY[{pair}] is {default_name!r} but registry "
            f"flags {defaults_by_pair[pair]!r} as default"
        )


def test_unlisted_model_modality_defaults_to_deterministic_template() -> None:
    """A model id absent from the registry falls back to the deterministic template."""
    unknown_model = "new-video-model"
    single_char_shot = _primitive(char_ids=["JADE"])

    got_modality, got_strategy = resolve_strategy(
        single_char_shot, model_id=unknown_model, requested_modality="r2v_multi"
    )

    assert got_modality == "r2v_multi"
    assert got_strategy.name == "deterministic_template"


def test_extensibility_is_one_registry_line(monkeypatch: pytest.MonkeyPatch) -> None:
    """Registering a single ``AuthorStrategy`` entry lets a primitive select it by name.

    ``monkeypatch.setitem`` restores ``AUTHOR_STRATEGIES`` on teardown. If an
    entry already existed under the key, it is put back rather than deleted,
    which a bare ``pop`` would have done.
    """
    key = ("seeddance-2.0", "r2v_multi", "silhouette_prose")
    monkeypatch.setitem(
        AUTHOR_STRATEGIES,
        key,
        AuthorStrategy(
            name="silhouette_prose",
            modality="r2v_multi",
            system_prompt_path=Path("silhouette_prose.md"),
            required_inputs=[],
            is_default=False,
            # Parameter named ``_shot`` so it does not shadow the module-level ``_primitive`` helper.
            applies=lambda _shot: True,
        ),
    )

    modality, strategy = resolve_strategy(
        _primitive(strategy="silhouette_prose"),
        model_id="seeddance-2.0",
        requested_modality="r2v_multi",
    )

    assert modality == "r2v_multi"
    assert strategy.name == "silhouette_prose"


def test_prose_author_fallback_env_short_circuits_to_template(
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """``PROSE_AUTHOR_FALLBACK=1`` forces the deterministic template and logs the reason.

    ``RECOIL_AUTHOR_STRATEGY`` is cleared first because an explicit strategy
    override inherited from the caller's environment could otherwise change the
    outcome under test.
    """
    monkeypatch.delenv("RECOIL_AUTHOR_STRATEGY", raising=False)
    monkeypatch.setenv("PROSE_AUTHOR_FALLBACK", "1")
    caplog.set_level("WARNING")

    modality, strategy = resolve_strategy(
        _primitive(char_ids=["JADE", "TORCH"]),
        model_id="seeddance-2.0",
    )

    assert modality == "r2v_multi"
    assert strategy.name == "deterministic_template"
    assert "prose_author_fallback" in caplog.text
    assert "reason=env_short_circuit" in caplog.text


def test_shot_spec_resolves_via_env_override(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``RECOIL_AUTHOR_STRATEGY=shot_spec`` selects the non-default shot_spec strategy.

    ``PROSE_AUTHOR_FALLBACK`` is cleared because, when set, it can
    short-circuit resolution to the deterministic template.
    """
    monkeypatch.delenv("PROSE_AUTHOR_FALLBACK", raising=False)
    monkeypatch.setenv("RECOIL_AUTHOR_STRATEGY", "shot_spec")

    modality, strategy = resolve_strategy(
        _primitive(
            char_ids=["JADE", "WREN"],
            timing_segments=[{"duration_s": 4.0}, {"duration_s": 4.0}],
        ),
        model_id="seeddance-2.0",
    )

    assert modality == "r2v_multi"
    assert strategy.name == "shot_spec"
    assert strategy.is_default is False
    assert strategy.system_prompt_path.is_file()


def test_shot_spec_resolves_via_primitive_strategy() -> None:
    """Setting ``strategy="shot_spec"`` on the primitive selects shot_spec on seeddance-2.0."""
    shot_spec_shot = _primitive(
        strategy="shot_spec",
        char_ids=["JADE"],
        timing_segments=[{"duration_s": 4.0}],
    )

    got_modality, got_strategy = resolve_strategy(shot_spec_shot, model_id="seeddance-2.0")

    assert got_modality == "r2v_multi"
    assert got_strategy.name == "shot_spec"


def test_directed_prose_remains_default_without_env(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With no strategy env vars set, seeddance-2.0 multi-char shots use directed_prose.

    Both ``RECOIL_AUTHOR_STRATEGY`` and ``PROSE_AUTHOR_FALLBACK`` are cleared.
    Either one, inherited from the caller's environment, changes which strategy
    is resolved.
    """
    monkeypatch.delenv("RECOIL_AUTHOR_STRATEGY", raising=False)
    monkeypatch.delenv("PROSE_AUTHOR_FALLBACK", raising=False)

    modality, strategy = resolve_strategy(
        _primitive(char_ids=["JADE", "WREN"]),
        model_id="seeddance-2.0",
    )

    assert modality == "r2v_multi"
    assert strategy.name == "directed_prose"
