from __future__ import annotations

import json
from pathlib import Path

import pytest

from recoil.pipeline._lib import dispatch_payload as dp
from recoil.pipeline._lib.author_strategies import AuthorInputError, resolve_strategy
from recoil.pipeline._lib.opus_oauth import OpusOAuthError
from recoil.pipeline._lib.plan_loader import CanonicalShot, CharacterEntry
from recoil.pipeline._lib.prompt_engine import BindAssertionError
from recoil.pipeline._lib.shot_primitive import ShotPrimitive


def _primitive() -> ShotPrimitive:
    """Build a fresh two-character, two-segment OTS shot primitive for the tests."""
    segments = [
        {"start_s": 0.0, "end_s": 3.0, "intent": "Jade moves first."},
        {"start_s": 3.0, "end_s": 6.0, "intent": "Wren answers."},
    ]
    manifest = {"identity_1": 1, "identity_2": 2}
    fields = {
        "shot_id": "EP002_PASS_009",
        "scene_index": 2,
        "shot_type": "OTS",
        "target_editorial_duration_s": 6.0,
        "intent": "Jade and Wren cross the med bay under pressure.",
        "camera_side": "B",
        "screen_direction": "left-to-right",
        "char_ids": ["JADE", "WREN"],
        "timing_segments": segments,
        "refs": {"manifest": manifest},
    }
    # Everything above is rebuilt per call, so tests may mutate the result freely.
    return ShotPrimitive(**fields)


def _shot(shot_id: str, char_id: str, duration_s: float = 3.0) -> CanonicalShot:
    """Build a single-character OTS ``CanonicalShot`` in the med bay location.

    Args:
        shot_id: Identifier stored on both the shot and its ``raw`` dict.
        char_id: Character id; its title-cased form is used in the prose fields.
        duration_s: Shot duration in seconds, mirrored into ``raw``.
    """
    display_name = char_id.title()
    prompt_data = {
        "shot_type": "OTS",
        "action_line": f"{display_name} crosses the med bay.",
        "emotion_line": "Breath held tight.",
    }
    raw_record = {
        "shot_id": shot_id,
        "scene_index": 2,
        "duration_s": duration_s,
        "shot_type": "OTS",
        "camera_side": "B",
        "screen_direction": "left-to-right",
        "source_text": f"{display_name} moves through pressure.",
        "asset_data": {"characters": [char_id]},
        "prompt_data": prompt_data,
    }
    cast = [CharacterEntry(char_id=char_id)]
    return CanonicalShot(
        shot_id=shot_id,
        scene_index=2,
        sequence_id=None,
        pipeline="video",
        previs_model="gemini-3-pro-image-preview",
        video_model="seeddance-2.0",
        location_id="tartarus_med_bay",
        characters=cast,
        shot_type="OTS",
        duration_s=duration_s,
        is_env_only=False,
        has_dialogue=False,
        aspect_ratio="9:16",
        raw=raw_record,
    )


def _ctx() -> dp.PayloadContext:
    """Build an ``r2v_multi`` payload context whose batch is a JADE shot then a WREN shot."""
    batch = [
        _shot(shot_id, char_id)
        for shot_id, char_id in (("EP002_SH01", "JADE"), ("EP002_SH02", "WREN"))
    ]
    # The lead shot is the very same object as the first batch entry.
    lead = batch[0]
    return dp.PayloadContext(
        project="tartarus",
        modality="r2v_multi",
        shot_id="EP002_PASS_009",
        shot=lead,
        batch_shots=batch,
        model_id="seeddance-2.0",
        bible={},
    )


@pytest.fixture(autouse=True)
def _patch_common(monkeypatch: pytest.MonkeyPatch):
    """Isolate every test from project config, reference-image lookup and ambient env.

    Setup:
      * removes ``PROSE_AUTHOR_FALLBACK`` from the environment (restored by
        ``monkeypatch`` afterwards);
      * empties ``dp._project_config_cache``;
      * stubs ``dp.load_project_config`` to return an empty dict;
      * stubs ``dp._collect_reference_images`` to return two fake image paths
        plus a two-identity manifest.

    Teardown empties ``dp._project_config_cache`` again so nothing cached
    during the test leaks into the next one.
    """
    # The parametrized fallback test sets this variable to trigger the
    # ``env_short_circuit`` reason. If it were already exported in the shell
    # running pytest, tests expecting the author path or a different fallback
    # reason would presumably take that short-circuit instead, so start every
    # test from a known-unset state. Tests that need it set it themselves.
    monkeypatch.delenv("PROSE_AUTHOR_FALLBACK", raising=False)

    dp._project_config_cache.clear()
    monkeypatch.setattr(dp, "load_project_config", lambda _project: {})
    monkeypatch.setattr(
        dp,
        "_collect_reference_images",
        lambda *args, **kwargs: (
            ["/tmp/jade.png", "/tmp/wren.png"],
            {"identity_1": 1, "identity_2": 2},
        ),
    )
    yield
    dp._project_config_cache.clear()


def test_author_pass_resolves_prose_author_model_inside(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``author_pass`` must look up the ("prose_author", "text") model itself."""
    _modality, strategy = resolve_strategy(_primitive(), model_id="seeddance-2.0")
    model_ids_seen: list[str] = []

    def fake_get_model(role, category):
        # Only the exact role/category pair yields the expected model id.
        if (role, category) == ("prose_author", "text"):
            return "claude-opus-4-8"
        return "wrong"

    def fake_call(model_id: str, _system: str, _user: str) -> str:
        model_ids_seen.append(model_id)
        return "[0:00-0:03] Jade pushes in.\n[0:03-0:06] Wren answers."

    monkeypatch.setattr(dp, "get_model", fake_get_model)
    monkeypatch.setattr(dp, "_call_author_model", fake_call)

    manifest = {"identity_1": 1, "identity_2": 2}
    text = dp.author_pass(
        _primitive(),
        strategy,
        bible={},
        project_config={},
        ref_manifest=manifest,
    )

    # The most recent author call must have received the resolved model id.
    assert model_ids_seen[-1] == "claude-opus-4-8"
    assert "Jade" in text


def test_call_author_model_routes_through_opus_oauth(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``_call_author_model`` forwards its arguments to ``call_opus_oauth`` and strips the reply."""
    recorded: dict[str, str] = {}

    def fake_call(model_id: str, system_prompt: str, user_prompt: str) -> str:
        recorded.update(
            model_id=model_id,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
        # Surrounding whitespace is deliberate: the caller is expected to trim it.
        return "  [0:00-0:03] Jade pushes in.  \n"

    monkeypatch.setattr(dp, "call_opus_oauth", fake_call)

    text = dp._call_author_model("claude-opus-4-8", "system", "user")

    expected_args = dict(
        model_id="claude-opus-4-8",
        system_prompt="system",
        user_prompt="user",
    )
    assert text == "[0:00-0:03] Jade pushes in."
    assert recorded == expected_args


def test_call_author_model_opus_oauth_error_falls_back_to_template(
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """An ``OpusOAuthError`` from the OAuth call yields the template prompt and a logged fallback."""
    caplog.set_level("WARNING")

    def failing_oauth(_model_id: str, _system_prompt: str, _user_prompt: str) -> str:
        raise OpusOAuthError(
            "claude OAuth call failed",
            model_id="claude-opus-4-8",
            returncode=1,
            stderr="auth failed",
        )

    def template_builder(*_args, **_kwargs) -> str:
        return "TEMPLATE"

    def builder_lookup(*_args, **_kwargs):
        # Whatever builder is requested, hand back the constant template stub.
        return template_builder

    monkeypatch.setattr(dp, "call_opus_oauth", failing_oauth)
    monkeypatch.setattr(dp, "get_builder", builder_lookup)

    helper_kwargs = {
        "model_id": "seeddance-2.0",
        "ref_manifest": {"identity_1": 1, "identity_2": 2},
        "segment_timestamps": [0.0, 3.0],
        "primitive_segment_timestamps": [(0.0, 3.0), (3.0, 6.0)],
        "bible": {},
        "project_config": {},
    }
    result = dp._build_author_aware_prompt(_ctx(), **helper_kwargs)

    assert result.prompt == "TEMPLATE"
    assert result.strategy == "deterministic_template"
    logged = caplog.text
    assert "prose_author_fallback" in logged
    assert "reason=author_call" in logged


def test_shared_helper_used_by_live_and_audit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Both payload entry points must go through ``_build_author_aware_prompt``."""
    seen_shot_ids: list[str] = []
    original_helper = dp._build_author_aware_prompt

    def recording_helper(ctx, **kwargs):
        # Note which context reached the helper, then defer to the real one.
        seen_shot_ids.append(ctx.shot_id)
        return original_helper(ctx, **kwargs)

    authored = "\n".join(
        (
            "[0:00-0:03] Jade leans forward.",
            "[0:03-0:06] Wren braces beside Jade.",
        )
    )

    def fake_author_pass(*_args, **_kwargs) -> str:
        return authored

    monkeypatch.setattr(dp, "_build_author_aware_prompt", recording_helper)
    monkeypatch.setattr(dp, "author_pass", fake_author_pass)

    dp.build_unified_payload(_ctx())

    lead_shot = _shot("EP002_SH01", "JADE")
    batch = [_shot("EP002_SH01", "JADE"), _shot("EP002_SH02", "WREN")]
    dp.build_dispatch_payload(
        shot=lead_shot,
        project="tartarus",
        modality="r2v_multi",
        batch_shots=batch,
        dry_run=True,
    )

    assert seen_shot_ids == ["EP002_PASS_009", "EP002_SH01"]


def test_authored_real_names_are_bound_before_payload_prompt(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Character names in authored prose must become ``@ImageN`` tokens in the payload prompt."""
    first_beat = (
        "[0:00-0:03] Jade pushes through Tartarus Med Bay as the lens "
        "drives with her tight breath."
    )
    second_beat = (
        "[0:03-0:06] Wren catches Jade by the sleeve, jaw locked as the "
        "move racks onto his hand."
    )
    authored = first_beat + "\n" + second_beat

    def fake_author_pass(*_args, **_kwargs) -> str:
        return authored

    monkeypatch.setattr(dp, "author_pass", fake_author_pass)

    prompt = dp.build_unified_payload(_ctx())["prompt"]

    for token in ("@Image1", "@Image2"):
        assert token in prompt
    for real_name in ("Jade", "Wren"):
        assert real_name not in prompt


@pytest.mark.parametrize(
    ("setup", "reason"),
    [
        ("env", "env_short_circuit"),
        ("input", "author_input"),
        ("call", "author_call"),
        ("bind", "bind_assertion"),
    ],
)
def test_author_failures_fallback_to_deterministic_template(
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
    setup: str,
    reason: str,
) -> None:
    """Each failure mode must yield the template prompt and log its own fallback reason."""
    caplog.set_level("WARNING")

    def _raising(make_exc):
        """Return a stub that raises a freshly built exception on every call."""

        def _stub(*_args, **_kwargs):
            raise make_exc()

        return _stub

    def _authored(*_args, **_kwargs) -> str:
        return "[0:00-0:03] Jade moves.\n[0:03-0:06] Wren follows."

    def _template(*_args, **_kwargs) -> str:
        return "TEMPLATE"

    # Attributes of ``dp`` to replace per scenario, applied in insertion order.
    patches_by_setup = {
        "input": {
            "author_pass": _raising(lambda: AuthorInputError("missing input")),
        },
        "call": {
            "author_pass": _raising(lambda: dp.AuthorCallError("opus failed")),
        },
        "bind": {
            "author_pass": _authored,
            "bind_named_prose": _raising(lambda: BindAssertionError("bad bind")),
        },
    }

    if setup == "env":
        monkeypatch.setenv("PROSE_AUTHOR_FALLBACK", "1")
    for attr_name, stub in patches_by_setup.get(setup, {}).items():
        monkeypatch.setattr(dp, attr_name, stub)

    monkeypatch.setattr(dp, "get_builder", lambda *_args, **_kwargs: _template)

    helper_kwargs = {
        "model_id": "seeddance-2.0",
        "ref_manifest": {"identity_1": 1, "identity_2": 2},
        "segment_timestamps": [0.0, 3.0],
        "primitive_segment_timestamps": [(0.0, 3.0), (3.0, 6.0)],
        "bible": {},
        "project_config": {},
    }
    result = dp._build_author_aware_prompt(_ctx(), **helper_kwargs)

    assert result.prompt == "TEMPLATE"
    assert result.strategy == "deterministic_template"
    logged = caplog.text
    assert "prose_author_fallback" in logged
    assert f"reason={reason}" in logged


def test_prompt_files_exist() -> None:
    """Check that the three strategy prompt templates are present on disk.

    The directory is looked up next to the imported ``dispatch_payload``
    module, so the check does not depend on the directory pytest was launched
    from. The original repo-root-relative path is kept as a fallback in case
    the module does not sit directly inside ``_lib``.
    """
    relative_root = Path("recoil/pipeline/_lib/prompts/strategies")
    # Assumes ``dp`` is a plain module file living in ``recoil/pipeline/_lib``;
    # when that does not hold, the ``is_dir`` test below selects the fallback.
    anchored_root = Path(dp.__file__).resolve().parent / "prompts" / "strategies"
    root = anchored_root if anchored_root.is_dir() else relative_root

    for name in ("directed_prose.md", "shot_spec.md", "start_end_frame.md"):
        prompt_path = root / name
        assert prompt_path.is_file(), f"missing strategy prompt: {prompt_path}"


def test_beat_skeleton_carries_setting() -> None:
    """A ``setting`` placed on each timing segment must appear on each rendered beat."""
    setting = "pod platform, beside the open cryo-pod"
    primitive = _primitive()
    for segment in primitive.timing_segments:
        segment["setting"] = setting

    _modality, strategy = resolve_strategy(primitive, model_id="seeddance-2.0")
    manifest = {"identity_1": 1, "identity_2": 2}
    rendered = dp._render_author_user_prompt(
        primitive,
        strategy,
        bible={},
        project_config={},
        ref_manifest=manifest,
    )

    # The rendered user prompt is parsed as JSON to reach the beat skeleton.
    skeleton = json.loads(rendered)["beat_skeleton"]
    assert len(skeleton) == 2
    assert all(beat["setting"] == setting for beat in skeleton)
