from __future__ import annotations

from io import BytesIO
import json
from types import SimpleNamespace

import pytest
from PIL import Image

from recoil.pipeline._lib import story_gate as sg


def _png_bytes(color=(255, 0, 0)) -> bytes:
    """Return the PNG encoding of a solid 20x20 RGB square filled with *color*."""
    with BytesIO() as sink:
        Image.new("RGB", (20, 20), color).save(sink, format="PNG")
        return sink.getvalue()


def _write_board(path) -> None:
    """Save a 40x40 RGB board image to *path*, split into four 20x20 quadrants.

    Quadrants are filled in row-major order: red (top-left), green
    (top-right), blue (bottom-left), yellow (bottom-right).
    """
    quadrant = 20
    board = Image.new("RGB", (2 * quadrant, 2 * quadrant))
    fills = ((255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0))
    for slot, fill in enumerate(fills):
        # slot 0..3 -> (row, col) in a 2-column grid, row-major.
        row, col = divmod(slot, 2)
        left, top = col * quadrant, row * quadrant
        board.paste(fill, (left, top, left + quadrant, top + quadrant))
    board.save(path)


def _packet(tmp_path) -> sg.StoryGatePacket:
    """Return a four-slot (2x2 grid) StoryGatePacket fixture.

    The board PNG is written into *tmp_path* with ``_write_board`` and
    referenced from the packet; panels are indexed 1 through 4.
    """
    board_png = tmp_path / "EP001_CONT_004_v03.png"
    _write_board(board_png)
    panel_specs = [{"index": number} for number in range(1, 5)]
    return sg.StoryGatePacket(
        board_id="EP001_CONT_004_v03",
        board_png=board_png,
        grid_cols=2,
        grid_rows=2,
        slots=4,
        generation_prompt="GENERATION PROMPT\nAUTHORING\n1. Jade faces pod.\n",
        beats_text="AUTHORING\n1. Jade faces pod.\n",
        scene_context="SCENE CONTEXT\nThe pod sparked.\n",
        panels=panel_specs,
        source_sha256="abc123",
        character_descriptions=None,
    )


def _valid_verdict() -> dict:
    return {
        "schema_version": 999,
        "judge_model": "unstamped",
        "prompt_version": "unstamped",
        "board_id": "unstamped",
        "source_sha256": "unstamped",
        "text_stageability": None,
        "panels": [
            {
                "index": idx,
                "description": "Jade faces the pod.",
                "forced_checks": {
                    name: {
                        "passed": True,
                        "severity": "SOFT",
                        "confidence": 0.9,
                        "reason": "ok",
                    }
                    for name in (
                        "depicts_beat",
                        "spatially_possible",
                        "eyeline_consistent",
                        "object_of_gaze_in_frame_and_front",
                        "causal_setup_present",
                    )
                },
                "fix_hint_injectable": False,
            }
            for idx in (1, 2, 3, 4)
        ],
        "transitions": [
            {
                "from": i,
                "to": i + 1,
                "forced_checks": {
                    "causal_setup_present": {
                        "passed": True,
                        "severity": "SOFT",
                        "confidence": 0.9,
                        "reason": "ok",
                    }
                },
                "fix_hint_injectable": False,
            }
            for i in (1, 2, 3)
        ],
        "routing": {
            "class": "ok",
            "confidence": 0.9,
            "evidence": "No staging issue.",
        },
    }


def _response(payload: str):
    return SimpleNamespace(content=[SimpleNamespace(text=payload)])


class _FakeMessages:
    def __init__(self, effects):
        self.effects = list(effects)
        self.calls = []

    def create(self, **kwargs):
        self.calls.append(kwargs)
        effect = self.effects.pop(0)
        if isinstance(effect, Exception):
            raise effect
        return _response(effect)


class _FakeClient:
    """Minimal SDK-client double exposing only ``messages``.

    ``messages`` is a ``_FakeMessages`` built from the scripted *effects*.
    """

    def __init__(self, effects):
        scripted = _FakeMessages(effects)
        self.messages = scripted


def test_evaluate_board_stamps_valid_json_and_omits_sampling_kwargs(tmp_path, monkeypatch):
    """A valid judge reply is stamped with gate metadata in a single call.

    The SDK request carries no sampling kwargs and lists the image blocks
    (full board plus one crop per slot in the default mode) before the
    trailing text block.
    """
    client = _FakeClient([json.dumps(_valid_verdict())])
    monkeypatch.setenv("RECOIL_CLAUDE_TRANSPORT", "sdk")
    monkeypatch.setattr(sg, "anthropic_client", lambda: client)
    packet = _packet(tmp_path)

    verdict = sg.StoryGate().evaluate_board(packet, model="fixture-model")

    # _valid_verdict() ships placeholder metadata; the gate must overwrite it.
    assert verdict["schema_version"] == sg.SCHEMA_VERSION
    assert verdict["judge_model"] == "fixture-model"
    assert verdict["prompt_version"] == sg.PROMPT_VERSION
    assert verdict["board_id"] == packet.board_id
    assert verdict["source_sha256"] == packet.source_sha256
    assert sg.validate_verdict(verdict) == []

    # A valid first reply must not trigger a re-ask.
    assert len(client.messages.calls) == 1
    call = client.messages.calls[0]
    assert call["model"] == "fixture-model"
    assert call["max_tokens"] == 8192
    for sampling_kwarg in ("temperature", "top_p", "top_k"):
        assert sampling_kwarg not in call
    content = call["messages"][0]["content"]
    # Derive the expected count from the packet, matching
    # test_use_crops_controls_image_count, rather than hard-coding 5.
    assert [block["type"] for block in content[:-1]] == ["image"] * (1 + packet.slots)
    assert content[-1]["type"] == "text"


def test_rate_limit_twice_then_success_backs_off(monkeypatch):
    """Two rate-limit errors followed by a JSON reply.

    ``_judge_call`` returns the parsed reply after backing off with sleeps
    of 2 and then 8.
    """

    class FakeRateLimit(Exception):
        pass

    recorded_sleeps = []
    monkeypatch.setenv("RECOIL_CLAUDE_TRANSPORT", "sdk")
    monkeypatch.setattr(sg.anthropic, "RateLimitError", FakeRateLimit)
    monkeypatch.setattr(sg.time, "sleep", recorded_sleeps.append)
    client = _FakeClient([FakeRateLimit("429 one"), FakeRateLimit("429 two"), '{"ok": true}'])
    monkeypatch.setattr(sg, "anthropic_client", lambda: client)

    result = sg._judge_call("prompt", None, model="fixture-model")

    assert result == {"ok": True}
    assert recorded_sleeps == [2, 8]


def test_persistent_rate_limit_raises_unavailable(monkeypatch):
    """Three consecutive rate-limit errors exhaust the attempts.

    ``StoryGateJudgeUnavailable`` is raised and its reason carries the last
    error's text. The backoff sleeps (2, then 8) are the same as in the
    recovering case.
    """

    class FakeRateLimit(Exception):
        pass

    recorded_sleeps = []
    monkeypatch.setenv("RECOIL_CLAUDE_TRANSPORT", "sdk")
    monkeypatch.setattr(sg.anthropic, "RateLimitError", FakeRateLimit)
    monkeypatch.setattr(sg.time, "sleep", recorded_sleeps.append)
    failures = [FakeRateLimit("429 one"), FakeRateLimit("429 two"), FakeRateLimit("429 three")]
    client = _FakeClient(failures)
    monkeypatch.setattr(sg, "anthropic_client", lambda: client)

    with pytest.raises(sg.StoryGateJudgeUnavailable) as raised:
        sg._judge_call("prompt", None, model="fixture-model")

    assert "429 three" in raised.value.reason
    assert recorded_sleeps == [2, 8]


def test_malformed_json_then_valid_reask_succeeds(monkeypatch):
    """A non-JSON first reply triggers exactly one re-ask.

    The second reply embeds a JSON object in surrounding text, and that
    object is extracted and returned.
    """
    monkeypatch.setenv("RECOIL_CLAUDE_TRANSPORT", "sdk")
    client = _FakeClient(["not json", 'prefix {"ok": true} suffix'])
    monkeypatch.setattr(sg, "anthropic_client", lambda: client)

    assert sg._judge_call("prompt", None, model="fixture-model") == {"ok": True}

    sent = client.messages.calls
    assert len(sent) == 2
    reask_text = sent[1]["messages"][0]["content"][-1]["text"]
    assert "Your previous reply was not valid JSON" in reask_text


def test_malformed_json_twice_raises_unavailable(monkeypatch):
    """Two non-JSON replies in a row fail closed.

    The failure surfaces as ``StoryGateJudgeUnavailable`` with a message
    mentioning invalid JSON.
    """
    fake = _FakeClient(["not json", "still not json"])
    monkeypatch.setattr(sg, "anthropic_client", lambda: fake)
    monkeypatch.setenv("RECOIL_CLAUDE_TRANSPORT", "sdk")

    with pytest.raises(sg.StoryGateJudgeUnavailable, match="invalid JSON"):
        sg._judge_call("prompt", None, model="fixture-model")


def test_use_crops_controls_image_count(tmp_path, monkeypatch):
    """``use_crops`` controls how many images are attached.

    ``use_crops=False`` attaches only the board image;
    ``use_crops=True`` adds one crop per packet slot.
    """
    packet = _packet(tmp_path)
    monkeypatch.setenv("RECOIL_CLAUDE_TRANSPORT", "sdk")

    expectations = ((False, 1), (True, 1 + packet.slots))
    for use_crops, expected_images in expectations:
        client = _FakeClient([json.dumps(_valid_verdict())])
        # Bind this iteration's client eagerly via a default argument.
        monkeypatch.setattr(sg, "anthropic_client", lambda fake=client: fake)
        sg.StoryGate().evaluate_board(packet, use_crops=use_crops, model="fixture-model")
        blocks = client.messages.calls[0]["messages"][0]["content"]
        assert sum(block["type"] == "image" for block in blocks) == expected_images


def test_schema_invalid_verdict_gets_one_reask(tmp_path, monkeypatch):
    """An unknown routing class is re-asked once with a schema complaint.

    The valid second verdict is the one returned.
    """
    bad_route = _valid_verdict()
    bad_route["routing"]["class"] = "bad_route"
    replies = [json.dumps(bad_route), json.dumps(_valid_verdict())]
    client = _FakeClient(replies)
    monkeypatch.setenv("RECOIL_CLAUDE_TRANSPORT", "sdk")
    monkeypatch.setattr(sg, "anthropic_client", lambda: client)
    packet = _packet(tmp_path)

    verdict = sg.StoryGate().evaluate_board(packet, model="fixture-model")

    assert verdict["routing"]["class"] == "ok"
    sent = client.messages.calls
    assert len(sent) == 2
    reask_text = sent[1]["messages"][0]["content"][-1]["text"]
    assert "did not match the required verdict schema" in reask_text


def test_judge_unavailable_verdict_is_schema_valid(monkeypatch):
    """The fallback verdict from ``judge_unavailable_verdict`` is schema-valid.

    It records the model reported by ``get_model``, routes to
    ``judge_unavailable`` and passes ``validate_verdict``.
    """
    monkeypatch.setattr(sg, "get_model", lambda role, category: "fixture-model")

    fallback = sg.judge_unavailable_verdict("board-1", "sha", "rate limited")

    assert fallback["judge_model"] == "fixture-model"
    assert fallback["routing"]["class"] == "judge_unavailable"
    assert sg.validate_verdict(fallback) == []


def test_images_are_base64_png_blocks_before_text(monkeypatch):
    """Images passed to ``_judge_call`` become base64 ``image/png`` blocks.

    The image blocks are placed ahead of the single prompt text block.
    """
    monkeypatch.setenv("RECOIL_CLAUDE_TRANSPORT", "sdk")
    client = _FakeClient(['{"ok": true}'])
    monkeypatch.setattr(sg, "anthropic_client", lambda: client)
    pngs = [_png_bytes(rgb) for rgb in ((1, 2, 3), (4, 5, 6))]

    sg._judge_call("prompt", pngs, model="fixture-model")

    blocks = client.messages.calls[0]["messages"][0]["content"]
    assert [b["type"] for b in blocks] == ["image", "image", "text"]
    first_source = blocks[0]["source"]
    assert first_source["media_type"] == "image/png"
    assert first_source["type"] == "base64"
    assert first_source["data"]
    assert blocks[-1]["text"] == "prompt"


def test_evaluate_text_fails_closed_on_malformed_contract(monkeypatch, tmp_path):
    """PR #76 gate r3: parseable-but-malformed text-stageability JSON fails closed.

    A reply such as ``{}`` must be re-asked once and then raise
    StoryGateJudgeUnavailable. It must never be read as a clean ok verdict.
    """
    seen_prompts = []

    def always_empty(prompt, images, *, model=None, max_attempts=3):
        seen_prompts.append(prompt)
        return {}

    monkeypatch.setattr(sg, "_judge_call", always_empty)
    gate = sg.StoryGate()
    with pytest.raises(sg.StoryGateJudgeUnavailable):
        gate.evaluate_text(_packet(tmp_path), model="mock")
    assert len(seen_prompts) == 2
    # The re-ask prompt must restate the output contract.
    assert "output contract" in seen_prompts[1]


def test_evaluate_text_recovers_on_valid_reask(monkeypatch, tmp_path):
    """A malformed first text-stageability reply is re-asked once.

    The first reply is missing ``findings``; the well-formed second reply
    must yield an ``ok`` routing.
    """
    responses = [
        {"stageable": True},  # malformed: findings missing
        {"stageable": True, "findings": []},
    ]
    prompts = []

    def judge(prompt, images, *, model=None, max_attempts=3):
        prompts.append(prompt)
        return responses.pop(0)

    monkeypatch.setattr(sg, "_judge_call", judge)
    gate = sg.StoryGate()
    verdict = gate.evaluate_text(_packet(tmp_path), model="mock")

    # Both scripted replies must be consumed. Without this check, wrongly
    # accepting the malformed first reply as a clean result would also pass.
    assert responses == []
    assert len(prompts) == 2
    assert "output contract" in prompts[1]
    assert verdict["routing"]["class"] == "ok"


def test_evaluate_text_fails_closed_on_failed_finding_missing_confidence(monkeypatch, tmp_path):
    """PR #76 gate r4: a FAILED text finding without the full contract fails closed.

    The finding lacks confidence, problem_kind and suggested_script_question
    (with their types). It must raise StoryGateJudgeUnavailable rather than
    being routed.
    """
    incomplete_finding = {
        "beat_index": 2,
        "check": "causal_setup_present",
        "passed": False,
        "severity": "HARD",
        "reason": "missing cause",
    }
    reply = {"stageable": False, "findings": [incomplete_finding]}

    def judge(prompt, images, *, model=None, max_attempts=3):
        return reply

    monkeypatch.setattr(sg, "_judge_call", judge)
    with pytest.raises(sg.StoryGateJudgeUnavailable):
        sg.StoryGate().evaluate_text(_packet(tmp_path), model="mock")


def test_evaluate_board_rejects_partial_panel_coverage(monkeypatch, tmp_path):
    """PR #76 gate r6: a schema-valid verdict with partial panel coverage fails closed.

    A reply with ``panels=[]`` (or missing indexes) is PARTIAL, not ok. It
    must raise a coverage error after exactly one re-ask.
    """
    calls = []

    def judge(prompt, images, *, model=None, max_attempts=3):
        calls.append(prompt)
        # Build a fresh payload on every call. A shallow dict() copy would
        # share the nested panels/transitions/routing objects between the
        # first reply and the re-ask reply.
        return {
            "text_stageability": None,
            "panels": [],
            "transitions": [],
            "routing": {"class": "ok", "confidence": 0.9, "evidence": "looks fine"},
        }

    monkeypatch.setattr(sg, "_judge_call", judge)
    with pytest.raises(sg.StoryGateJudgeUnavailable) as exc:
        sg.StoryGate().evaluate_board(_packet(tmp_path), model="mock")
    assert "coverage" in str(exc.value)
    # Initial ask plus the single re-ask promised by the docstring.
    assert len(calls) == 2


def test_no_crops_prompt_declares_single_image(tmp_path):
    """PR #76 gate r6: the no-crops calibration prompt declares a single image.

    It must announce exactly one attached image and never describe
    per-panel crops that are not sent.
    """
    packet = _packet(tmp_path)
    prompts = {
        mode: sg.build_image_judge_prompt(packet, use_crops=mode)
        for mode in (True, False)
    }
    assert "panel 1 crop" in prompts[True]
    assert "panel 1 crop" not in prompts[False]
    assert "Exactly ONE image is attached" in prompts[False]
