from __future__ import annotations

import sys
import types
from collections.abc import Iterator

import pytest

from recoil.core import anthropic_client as anthropic_client_mod
from recoil.core import claude_cli
from recoil.pipeline._lib import world_state_pass as wsp


@pytest.fixture(autouse=True)
def _reset_world_state_breaker() -> Iterator[None]:
    """Run every test against a fresh world-state breaker and leave none behind.

    The breaker's consecutive-failure count lives in ``world_state_pass`` and
    persists across ``derive_settings`` calls. Resetting on teardown as well
    as setup keeps failures recorded by a test here from leaking into tests
    in other modules that run afterwards.
    """
    wsp._reset_breaker_state()
    yield
    # Teardown: don't hand accumulated failures to whichever test runs next.
    wsp._reset_breaker_state()


def _segments() -> list[dict]:
    return [
        {"shot_id": "S1", "intent": "Jade opens the pod.", "sublocation": "pod_platform"},
        {"shot_id": "S2", "intent": "Wren sits up."},
    ]


class _TextBlock:
    def __init__(self, text: str) -> None:
        self.text = text


class _Usage:
    input_tokens = 100
    output_tokens = 20


class _Response:
    """Fake ``messages.create`` result: one text block plus a usage record."""

    def __init__(self, text: str) -> None:
        only_block = _TextBlock(text)
        self.content = [only_block]
        self.usage = _Usage()


class _Messages:
    """Fake ``client.messages`` namespace that records each ``create`` call."""

    def __init__(self, text: str) -> None:
        self.calls: list[dict] = []
        self.text = text

    def create(self, **kwargs):
        """Remember the keyword arguments, then reply with the canned text."""
        self.calls.append(kwargs)
        canned = _Response(self.text)
        return canned


class _Client:
    """Minimal fake SDK client: exposes only a recording ``messages`` namespace."""

    def __init__(self, text: str) -> None:
        recorder = _Messages(text)
        self.messages = recorder


def test_world_state_cli_transport_uses_claude_cli_and_applies_settings(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """On the CLI transport, derive_settings calls claude_cli exactly once.

    The SDK client must never be constructed, and each line of the CLI reply
    becomes the ``setting`` of the corresponding segment.
    """
    cli_calls: list[dict] = []

    def fake_cli_call(prompt, images=None, *, system_prompt=None, model=None, timeout_s=900):
        # Record exactly what world_state_pass forwards to the CLI layer.
        cli_calls.append(
            {
                "prompt": prompt,
                "images": images,
                "system_prompt": system_prompt,
                "model": model,
                "timeout_s": timeout_s,
            }
        )
        return "Pod platform: Jade beside the open pod.\nWren seated inside it."

    monkeypatch.setattr(claude_cli, "claude_transport", lambda: "cli")
    monkeypatch.setattr(claude_cli, "claude_cli_call", fake_cli_call)
    monkeypatch.setattr(
        anthropic_client_mod,
        "anthropic_client",
        lambda: pytest.fail("SDK client must not be constructed on CLI transport"),
    )

    result = wsp.derive_settings(
        _segments(),
        location_id="shaft",
        char_ids=["JADE", "WREN"],
        model="claude-test",
    )

    # Check the count first so a missing call fails with a clear message
    # instead of an IndexError while building the expected value.
    assert len(cli_calls) == 1, f"expected exactly one CLI call, got {cli_calls!r}"
    (call,) = cli_calls
    # The prompt is free text; check its content separately below instead of
    # comparing it against itself.
    prompt = call.pop("prompt")
    assert call == {
        "images": None,
        "system_prompt": wsp.WORLD_STATE_SYSTEM,
        "model": "claude-test",
        "timeout_s": 900,
    }
    assert '"segment_count": 2' in prompt
    assert [segment["setting"] for segment in result] == [
        "Pod platform: Jade beside the open pod.",
        "Wren seated inside it.",
    ]


def test_world_state_sdk_transport_uses_sdk_path(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """On the SDK transport the model call goes through the Anthropic client, never the CLI."""

    def _cli_forbidden(*_args, **_kwargs):
        pytest.fail("CLI call must not run on SDK transport")

    fake_client = _Client("SDK-authored setting.")

    monkeypatch.setattr(claude_cli, "claude_transport", lambda: "sdk")
    monkeypatch.setattr(claude_cli, "claude_cli_call", _cli_forbidden)
    # Placeholder module for "anthropic" — presumably so the SDK branch's import
    # succeeds without the real package installed; confirm against wsp.
    monkeypatch.setitem(sys.modules, "anthropic", types.SimpleNamespace())
    monkeypatch.setattr(anthropic_client_mod, "anthropic_client", lambda: fake_client)

    reply = wsp._call_world_state_model("claude-sdk", "system text", "user text")

    expected_request = {
        "model": "claude-sdk",
        "max_tokens": 1024,
        "system": "system text",
        "messages": [{"role": "user", "content": "user text"}],
    }
    assert reply == "SDK-authored setting."
    assert fake_client.messages.calls == [expected_request]


def test_world_state_breaker_raises_on_third_consecutive_exception(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Two model exceptions fall back to the input; the third trips the breaker."""

    def _always_fail(*_args, **_kwargs):
        raise RuntimeError("model unavailable")

    monkeypatch.setattr(wsp, "_call_world_state_model", _always_fail)

    for _attempt in range(2):
        segments = _segments()
        # Below the threshold the pass hands back the very list it was given.
        fallback = wsp.derive_settings(segments, location_id=None, char_ids=[], model="m")
        assert fallback is segments

    with pytest.raises(wsp.WorldStatePassOutage, match="RuntimeError.*systemically failing"):
        wsp.derive_settings(_segments(), location_id=None, char_ids=[], model="m")


def test_world_state_line_count_mismatch_counts_toward_breaker(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A reply whose line count differs from the segment count counts as a breaker failure.

    Two exceptions prime the breaker. A one-line reply for two segments is
    then the third consecutive failure and must trip it.
    """

    def boom(*_args, **_kwargs):
        raise RuntimeError("model unavailable")

    monkeypatch.setattr(wsp, "_call_world_state_model", boom)
    for _ in range(2):
        segments = _segments()
        # Identity, not truthiness: below the threshold the fallback returns
        # the very list it was given.
        assert wsp.derive_settings(segments, location_id=None, char_ids=[], model="m") is segments

    # One output line for two segments -> malformed output.
    monkeypatch.setattr(wsp, "_call_world_state_model", lambda *_a, **_k: "one line")

    with pytest.raises(wsp.WorldStatePassOutage, match="malformed output.*systemically failing"):
        wsp.derive_settings(_segments(), location_id=None, char_ids=[], model="m")


def test_world_state_success_resets_breaker(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A successful call clears the consecutive-failure count.

    After the reset the breaker is re-armed from zero: two failures fall back
    again, and the third consecutive failure trips it.
    """

    def boom(*_args, **_kwargs):
        raise RuntimeError("model unavailable")

    def assert_falls_back() -> None:
        # Below the threshold the fallback returns the very list it was given.
        segments = _segments()
        assert wsp.derive_settings(segments, location_id=None, char_ids=[], model="m") is segments

    monkeypatch.setattr(wsp, "_call_world_state_model", boom)
    assert_falls_back()
    assert_falls_back()

    monkeypatch.setattr(
        wsp,
        "_call_world_state_model",
        lambda *_a, **_k: "Pod platform: Jade beside the pod.\nWren seated inside.",
    )
    result = wsp.derive_settings(_segments(), location_id=None, char_ids=[], model="m")
    assert [segment["setting"] for segment in result] == [
        "Pod platform: Jade beside the pod.",
        "Wren seated inside.",
    ]

    # Without the reset, the first of these would be the third consecutive
    # failure and would raise.
    monkeypatch.setattr(wsp, "_call_world_state_model", boom)
    assert_falls_back()
    assert_falls_back()

    # The count restarted at zero, so the third post-success failure trips the
    # breaker again.
    with pytest.raises(wsp.WorldStatePassOutage, match="RuntimeError.*systemically failing"):
        wsp.derive_settings(_segments(), location_id=None, char_ids=[], model="m")
