from __future__ import annotations

import copy

import pytest

from recoil.pipeline._lib import dispatch_payload as dp
from recoil.pipeline._lib import world_state_pass as wsp
from recoil.pipeline._lib.plan_loader import CanonicalShot, CharacterEntry


def _segments() -> list[dict]:
    return [
        {
            "shot_id": "EP001_SH01",
            "start_s": 0.0,
            "end_s": 3.0,
            "duration_s": 3.0,
            "intent": "Jade opens the cryo-pod.",
            "sublocation": "pod_platform",
        },
        {
            "shot_id": "EP001_SH02",
            "start_s": 3.0,
            "end_s": 6.0,
            "duration_s": 3.0,
            "intent": "Wren sits up inside the pod.",
        },
    ]


# ── derive_settings: mapping ────────────────────────────────────────────


def test_derive_settings_maps_lines_to_segments(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Each model output line becomes the ``setting`` of the matching segment."""
    expected_settings = [
        "Pod platform: Jade beside the open cryo-pod.",
        "Pod platform: Wren sits up inside the open pod; Jade beside it.",
    ]
    # The stubbed model answers with one line per segment, newline-separated.
    model_reply = "\n".join(expected_settings)

    def fake_model(*_args, **_kwargs):
        return model_reply

    monkeypatch.setattr(wsp, "_call_world_state_model", fake_model)

    derived = wsp.derive_settings(
        _segments(),
        location_id="int_shaft",
        char_ids=["JADE", "WREN"],
        model="claude-test",
    )

    derived_settings = [segment["setting"] for segment in derived]
    assert derived_settings == expected_settings
    # Keys that were already on the segments come back unchanged.
    first, second = derived[0], derived[1]
    assert first["shot_id"] == "EP001_SH01"
    assert second["intent"] == "Wren sits up inside the pod."


# ── derive_settings: length cap (word-boundary truncation) ──────────────


def test_derive_settings_truncates_over_cap_at_word_boundary(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An over-cap line is cut to the limit without leaving a partial word.

    The first model line is built from repeated ``"word "`` tokens so every
    gap is a word boundary; the second line is short and must pass through.
    """
    token = "word "
    # Size the over-cap line from the cap itself so the test keeps exercising
    # truncation even if SETTING_CHAR_LIMIT is raised later. Never go below
    # the historical 40 repeats (200 chars).
    # NOTE(review): assumes SETTING_CHAR_LIMIT is an int — confirm.
    repeats = max(40, wsp.SETTING_CHAR_LIMIT // len(token) + 10)
    long_line = token * repeats
    monkeypatch.setattr(
        wsp,
        "_call_world_state_model",
        lambda *a, **k: f"{long_line}\nshort second line",
    )

    result = wsp.derive_settings(
        _segments(),
        location_id=None,
        char_ids=[],
        model="claude-test",
    )

    setting0 = result[0]["setting"]
    assert len(setting0) <= wsp.SETTING_CHAR_LIMIT
    assert not setting0.endswith(" ")
    # Cut at a word boundary → only whole "word" tokens, no partial fragment.
    # seg[0] is sublocation-tagged, so the repaired "Pod platform:" prefix
    # joins the whole-word tokens (merge-gate r11 anchoring contract).
    assert set(setting0.split()) <= {"Pod", "platform:", "word"}
    assert "word" in setting0
    # Sub-cap line passes through untouched.
    assert result[1]["setting"] == "short second line"


# ── derive_settings: fail-soft on transport exception ───────────────────


def test_derive_settings_failsoft_on_transport_exception(
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A raising model call hands back the input list as-is and logs a skip."""

    def boom(*_a, **_k):
        raise RuntimeError("model timeout")

    caplog.set_level("WARNING")
    monkeypatch.setattr(wsp, "_call_world_state_model", boom)

    given = _segments()
    # Seed one segment with a setting: the fail-soft path must leave it alone.
    given[0]["setting"] = "pre-existing setting"

    returned = wsp.derive_settings(
        given,
        location_id=None,
        char_ids=[],
        model="claude-test",
    )

    # Same list object comes back — nothing copied, nothing rewritten.
    assert returned is given
    head, tail = returned[0], returned[1]
    assert head["setting"] == "pre-existing setting"  # kept, not stripped
    assert "setting" not in tail  # nothing added
    assert "world_state_pass_skipped" in caplog.text


def test_derive_settings_failsoft_on_line_count_mismatch(
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Three model lines for two segments: input returned as-is, skip logged."""

    def three_lines(*_args, **_kwargs):
        # One line more than there are segments.
        return "line one\nline two\nline three"

    caplog.set_level("WARNING")
    monkeypatch.setattr(wsp, "_call_world_state_model", three_lines)

    given = _segments()
    returned = wsp.derive_settings(
        given,
        location_id=None,
        char_ids=[],
        model="claude-test",
    )

    assert returned is given
    assert not any("setting" in segment for segment in returned)
    assert "world_state_pass_skipped" in caplog.text


# ── derive_settings: zero mutation of the input list ────────────────────


def test_derive_settings_does_not_mutate_input(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Derived settings land on a new list; the caller's segments stay as given."""

    def two_lines(*_args, **_kwargs):
        return "Setting one line.\nSetting two line."

    monkeypatch.setattr(wsp, "_call_world_state_model", two_lines)

    given = _segments()
    snapshot = copy.deepcopy(given)
    derived = wsp.derive_settings(
        given,
        location_id=None,
        char_ids=[],
        model="claude-test",
    )

    # The input list and each of its dicts still match the pre-call snapshot.
    assert given == snapshot
    assert not any("setting" in segment for segment in given)
    # The settings live on a different list object.
    assert derived is not given
    derived_settings = [segment["setting"] for segment in derived]
    assert derived_settings == [
        "Pod platform: Setting one line.",
        "Setting two line.",
    ]


# ── wiring: authored r2v_multi build seam ───────────────────────────────


def _shot(shot_id: str, char_id: str, duration_s: float = 3.0) -> CanonicalShot:
    """Build a single-character OTS ``CanonicalShot`` in scene 2 for wiring tests.

    The character id is title-cased for the prose fields of the raw plan
    entry; every other value is a fixed literal.
    """
    display_name = char_id.title()
    scene_index = 2
    shot_type = "OTS"

    prompt_data = {
        "shot_type": shot_type,
        "action_line": f"{display_name} crosses the bay.",
        "emotion_line": "Breath held tight.",
    }
    raw = {
        "shot_id": shot_id,
        "scene_index": scene_index,
        "duration_s": duration_s,
        "shot_type": shot_type,
        "camera_side": "B",
        "screen_direction": "left-to-right",
        "strategy": "shot_spec",  # force shot_spec resolution
        "source_text": f"{display_name} moves through pressure.",
        "asset_data": {"characters": [char_id]},
        "prompt_data": prompt_data,
    }
    cast = [CharacterEntry(char_id=char_id)]

    return CanonicalShot(
        shot_id=shot_id,
        scene_index=scene_index,
        sequence_id=None,
        pipeline="video",
        previs_model="gemini-3-pro-image-preview",
        video_model="seeddance-2.0",
        location_id=None,  # no location → registry path skipped
        characters=cast,
        shot_type=shot_type,
        duration_s=duration_s,
        is_env_only=False,
        has_dialogue=False,
        aspect_ratio="9:16",
        raw=raw,
    )


def _ctx() -> dp.PayloadContext:
    """Build an ``r2v_multi`` payload context batching a JADE and a WREN shot.

    The JADE shot is both the context's primary ``shot`` and the first
    entry of ``batch_shots``.
    """
    batch = [
        _shot("EP002_SH01", "JADE"),
        _shot("EP002_SH02", "WREN"),
    ]
    primary = batch[0]
    return dp.PayloadContext(
        project="tartarus",
        modality="r2v_multi",
        shot_id="EP002_PASS_009",
        shot=primary,
        batch_shots=batch,
        model_id="seeddance-2.0",
        bible={},
    )


@pytest.fixture
def _patch_common(monkeypatch: pytest.MonkeyPatch):
    """Stub project-config loading and reference-image collection on ``dp``.

    The project-config cache is cleared before the stubs go in and again
    after the test, so cached config does not carry across tests.
    """

    def _empty_config(_project):
        return {}

    def _fake_reference_images(*args, **kwargs):
        image_paths = ["/tmp/jade.png", "/tmp/wren.png"]
        identity_slots = {"identity_1": 1, "identity_2": 2}
        return (image_paths, identity_slots)

    dp._project_config_cache.clear()
    monkeypatch.setattr(dp, "load_project_config", _empty_config)
    monkeypatch.setattr(dp, "_collect_reference_images", _fake_reference_images)
    yield
    dp._project_config_cache.clear()


def test_authored_path_injects_settings_when_enabled(
    monkeypatch: pytest.MonkeyPatch,
    _patch_common,
) -> None:
    """Settings reach the authored primitive only when the env flag is set.

    Phase one sets ``RECOIL_WORLD_STATE_PASS=1`` and checks the primitive
    handed to ``author_pass`` carries the derived settings. Phase two unsets
    the flag and checks ``derive_settings`` is not called at all.
    """

    def fake_model(*_args, **_kwargs):
        return "Bay east: Jade by the hatch.\nBay east: Wren braces beside Jade."

    monkeypatch.setattr(wsp, "_call_world_state_model", fake_model)

    # Wrap the REAL derive_settings: record what went in and what came out
    # while still running the live mapping, so the wire-in is exercised
    # end to end.
    real_derive = wsp.derive_settings
    derive_io: dict = {}
    authored_primitives: list = []

    def derive_spy(segments, **kwargs):
        derive_io["input"] = segments
        derived = real_derive(segments, **kwargs)
        derive_io["output"] = derived
        return derived

    def author_spy(primitive, _strategy, **_kwargs):
        # Record the primitive author_pass is handed, then return canned text.
        authored_primitives.append(primitive)
        return "[0:00-0:03] Jade pushes in.\n[0:03-0:06] Wren braces beside Jade."

    monkeypatch.setattr(dp, "derive_settings", derive_spy)
    monkeypatch.setattr(dp, "author_pass", author_spy)

    # ── enabled: settings injected into the clone, original untouched ──
    monkeypatch.setenv("RECOIL_WORLD_STATE_PASS", "1")
    dp.build_unified_payload(_ctx())

    clone = authored_primitives[0]
    injected = [segment.get("setting") for segment in clone.timing_segments]
    assert injected == [
        "Bay east: Jade by the hatch.",
        "Bay east: Wren braces beside Jade.",
    ]
    # The primitive's segments are exactly the list derive_settings returned.
    assert clone.timing_segments is derive_io["output"]
    # The segments passed in are a different list and gained no settings.
    assert derive_io["input"] is not derive_io["output"]
    assert not any("setting" in segment for segment in derive_io["input"])

    # ── disabled: derive_settings is never called ──
    monkeypatch.delenv("RECOIL_WORLD_STATE_PASS", raising=False)
    disabled_calls: list = []

    def record_call(segments, **kwargs):
        disabled_calls.append(segments)
        return segments

    monkeypatch.setattr(dp, "derive_settings", record_call)
    dp.build_unified_payload(_ctx())
    assert disabled_calls == []


def test_sublocation_prefix_repaired(monkeypatch: pytest.MonkeyPatch) -> None:
    """Merge-gate r11: settings for sublocation-tagged segments must lead with it.

    The stubbed model line omits the sublocation, so the derived setting
    has to come back with the "pod platform" prefix (case-insensitive).
    """
    monkeypatch.setattr(
        wsp, "_call_world_state_model", lambda *a, **k: "Jade beside the open pod."
    )
    segs = [{"duration_s": 3.0, "sublocation": "pod_platform"}]
    # No ``model=`` here: the call relies on derive_settings' default.
    out = wsp.derive_settings(segs, location_id="shaft", char_ids=["JADE"])
    assert out[0]["setting"].lower().startswith("pod platform")
