from __future__ import annotations

import json
from pathlib import Path

import pytest

from recoil.core.paths import ProjectPaths
from recoil.pipeline._lib import breakdown_extract
from recoil.pipeline._lib.breakdown_extract import (
    BreakdownExtractError,
    extract_mention_ledger,
)


# Minimal episode-1 script used as the default script text by `_make_project`.
# It contains three slugline-headed sections (INT./INT./EXT.); the tests in
# this file expect extraction to report them as EP001_SC001..EP001_SC003.
# The canned model responses quote these lines verbatim in `span_quote`, and
# the carry-forward test rewrites the "WREN climbs onto the pod platform."
# sentence via str.replace, so that sentence must keep appearing verbatim here.
SCRIPT_FIXTURE = """# Episode 1

INT. LOWER DECKS - NIGHT
JADE enters the corridor.
The cracked cryo pod hisses.

INT. MAINTENANCE SHAFT - NIGHT
WREN climbs onto the pod platform.
Jade removes her torn jacket.

EXT. HULL GANTRY - DAWN
Jade sees Wren's blue eyes in the glare.
"""


def _make_project(tmp_path: Path, script_text: str = SCRIPT_FIXTURE) -> ProjectPaths:
    """Create a throwaway ``tartarus`` project tree under *tmp_path*.

    Two files are written (both UTF-8):

    * ``scripting/episodes/ep_001.md`` containing *script_text*;
    * ``_pipeline/state/visual/global_bible.json`` containing a small bible
      with two characters, three locations (one with a sublocation) and one
      prop.

    Returns the ``ProjectPaths`` built from the project root.
    """
    project_root = tmp_path / "tartarus"

    # Episode script.
    episodes_dir = project_root / "scripting" / "episodes"
    episodes_dir.mkdir(parents=True)
    (episodes_dir / "ep_001.md").write_text(script_text, encoding="utf-8")

    # Global bible: the ids listed here are the ones the canned model
    # responses in this file refer to.
    bible = {
        "characters": {"jade": {}, "wren": {}},
        "locations": {
            "lower_decks": {},
            "maintenance_shaft": {
                "sublocations": {"pod_platform": {"description": "Platform."}}
            },
            "hull_gantry": {},
        },
        "props": {"cryo_pod": {}},
    }
    visual_state_dir = project_root / "_pipeline" / "state" / "visual"
    visual_state_dir.mkdir(parents=True)
    (visual_state_dir / "global_bible.json").write_text(
        json.dumps(bible),
        encoding="utf-8",
    )

    return ProjectPaths.from_root(project_root)


def _patch_paths(monkeypatch: pytest.MonkeyPatch, paths: ProjectPaths) -> None:
    """Make ``ProjectPaths.for_project`` return *paths* for any project name.

    The replacement is installed through *monkeypatch* on the ``ProjectPaths``
    class referenced by ``breakdown_extract``.
    """

    def _for_project(cls, project):
        # Ignore the requested project; always hand back the temp project.
        return paths

    monkeypatch.setattr(
        breakdown_extract.ProjectPaths,
        "for_project",
        classmethod(_for_project),
    )


def _response_for_prompt(prompt: str, calls: list[str]) -> str:
    """Return the canned extraction response for the scene named in *prompt*.

    *prompt* must be a JSON object with a ``scene_id`` key. The scene id is
    appended to *calls* (so tests can see which scenes were requested), and a
    JSON string of the form ``{"mentions": [...]}`` is returned:

    * ``EP001_SC001`` -> character, location and prop_state mentions;
    * ``EP001_SC002`` -> character, sublocation and wardrobe_change mentions;
    * any other scene id -> a single identity_observation mention.
    """
    scene_id = json.loads(prompt)["scene_id"]
    calls.append(scene_id)

    lower_decks_mentions = [
        {
            "kind": "character",
            "surface_text": "JADE",
            "character_id": "jade",
            "span_quote": "JADE enters the corridor.",
        },
        {
            "kind": "location",
            "surface_text": "LOWER DECKS",
            "location_id": "lower_decks",
            "span_quote": "INT. LOWER DECKS - NIGHT",
        },
        {
            "kind": "prop_state",
            "surface_text": "cracked cryo pod",
            "prop_id": "cryo_pod",
            "state_id": "cracked",
            "span_quote": "The cracked cryo pod hisses.",
        },
    ]
    maintenance_shaft_mentions = [
        {
            "kind": "character",
            "surface_text": "WREN",
            "character_id": "wren",
            "span_quote": "WREN climbs onto the pod platform.",
        },
        {
            "kind": "sublocation",
            "surface_text": "pod platform",
            "location_id": "maintenance_shaft",
            "sublocation": "pod_platform",
            "span_quote": "WREN climbs onto the pod platform.",
        },
        {
            "kind": "wardrobe_change",
            "surface_text": "removes her torn jacket",
            "character_id": "jade",
            "piece": "jacket",
            "change": "removed",
            "span_quote": "Jade removes her torn jacket.",
        },
    ]
    default_mentions = [
        {
            "kind": "identity_observation",
            "surface_text": "Wren's blue eyes",
            "character_id": "wren",
            "attribute": "eye_color",
            "observed_value": "blue",
            "span_quote": "Jade sees Wren's blue eyes in the glare.",
        }
    ]

    # Plain equality checks (rather than a dict lookup) so that any scene id
    # value, whatever its type, simply falls through to the default response.
    if scene_id == "EP001_SC001":
        mentions = lower_decks_mentions
    elif scene_id == "EP001_SC002":
        mentions = maintenance_shaft_mentions
    else:
        mentions = default_mentions

    return json.dumps({"mentions": mentions})


def _patch_llm(monkeypatch: pytest.MonkeyPatch, calls: list[str]) -> None:
    """Replace model selection and the extraction model call with test stubs.

    ``breakdown_extract.get_model`` is patched to always return
    ``"test-model"``. ``breakdown_extract._call_extraction_model`` is patched
    with a stub that sanity-checks its arguments and answers with the canned
    response from ``_response_for_prompt``, which records each requested scene
    id in *calls*.
    """

    def fake_get_model(role, category):
        return "test-model"

    def fake_call(model: str, system_prompt: str, user_prompt: str) -> str:
        # Check what the code under test actually sends to the model.
        assert model == "test-model"
        assert "S1 continuity mention extractor" in system_prompt
        assert "normalization_vocabulary" in user_prompt
        return _response_for_prompt(user_prompt, calls)

    monkeypatch.setattr(breakdown_extract, "get_model", fake_get_model)
    monkeypatch.setattr(breakdown_extract, "_call_extraction_model", fake_call)


def test_full_extraction_writes_three_scene_ledger(monkeypatch, tmp_path):
    """A fresh extraction calls the model once per scene and writes the ledger."""
    project_paths = _make_project(tmp_path)
    _patch_paths(monkeypatch, project_paths)
    requested_scene_ids: list[str] = []
    _patch_llm(monkeypatch, requested_scene_ids)

    ledger = extract_mention_ledger("tartarus", 1, model="test-model")

    # One model call per scene, in script order.
    assert requested_scene_ids == ["EP001_SC001", "EP001_SC002", "EP001_SC003"]

    # Top-level ledger fields.
    assert ledger["schema_version"] == 1
    assert "Gate-A-ratified CACHE" in ledger["_meta"]["note"]
    assert ledger["project"] == "tartarus"
    assert ledger["episode"] == 1
    assert len(ledger["script_content_hash"]) == 64

    # Nothing can be carried forward on a first run.
    assert len(ledger["scenes"]) == 3
    carried_flags = [scene["carried_forward"] for scene in ledger["scenes"]]
    assert carried_flags == [False] * 3

    ledger_file = project_paths.episode_breakdown_dir(1) / "mention_ledger.json"
    assert ledger_file.is_file()

    # Every mention is well-formed and stamped with its owning scene.
    scene_mention_pairs = (
        (scene, mention)
        for scene in ledger["scenes"]
        for mention in scene["mentions"]
    )
    for scene, mention in scene_mention_pairs:
        _assert_required_shape(mention)
        assert mention["scene_id"] == scene["scene_id"]
        assert mention["scene_hash"] == scene["scene_hash"]


def test_carry_forward_only_reextracts_changed_scene(monkeypatch, tmp_path):
    """Editing one scene re-extracts only that scene; the others carry forward."""
    project_paths = _make_project(tmp_path)
    _patch_paths(monkeypatch, project_paths)

    # First pass: populate the ledger for the unmodified script.
    first_pass_calls: list[str] = []
    _patch_llm(monkeypatch, first_pass_calls)
    extract_mention_ledger("tartarus", 1, model="test-model")

    # Change a single line inside the second scene and rewrite the script.
    original_line = "WREN climbs onto the pod platform."
    edited_line = "WREN climbs onto the pod platform with a damaged glove."
    script_file = project_paths.episodes_dir / "ep_001.md"
    script_file.write_text(
        SCRIPT_FIXTURE.replace(original_line, edited_line),
        encoding="utf-8",
    )

    second_calls: list[str] = []

    def fake_changed_call(model: str, system_prompt: str, user_prompt: str) -> str:
        # Record the requested scene and answer with one transient_state mention.
        second_calls.append(json.loads(user_prompt)["scene_id"])
        mention = {
            "kind": "transient_state",
            "surface_text": "damaged glove",
            "character_id": "wren",
            "state_desc": "wearing a damaged glove",
            "span_quote": "WREN climbs onto the pod platform with a damaged glove.",
        }
        return json.dumps({"mentions": [mention]})

    monkeypatch.setattr(breakdown_extract, "_call_extraction_model", fake_changed_call)

    # Second pass over the edited script.
    ledger = extract_mention_ledger("tartarus", 1, model="test-model")

    assert second_calls == ["EP001_SC002"]
    carried_flags = [scene["carried_forward"] for scene in ledger["scenes"]]
    assert carried_flags == [True, False, True]

    # Untouched scenes keep their first-pass mentions; the edited one is fresh.
    scenes = ledger["scenes"]
    assert scenes[0]["mentions"][0]["character_id"] == "jade"
    assert scenes[2]["mentions"][0]["attribute"] == "eye_color"
    assert scenes[1]["mentions"][0]["kind"] == "transient_state"


def test_missing_kind_required_field_raises(monkeypatch, tmp_path):
    """A prop_state mention without ``state_id`` is rejected with an error."""
    project_paths = _make_project(tmp_path)
    _patch_paths(monkeypatch, project_paths)

    # prop_state mention that deliberately omits "state_id".
    incomplete_mention = {
        "kind": "prop_state",
        "surface_text": "cracked cryo pod",
        "prop_id": "cryo_pod",
        "span_quote": "The cracked cryo pod hisses.",
    }

    def bad_call(model: str, system_prompt: str, user_prompt: str) -> str:
        return json.dumps({"mentions": [incomplete_mention]})

    monkeypatch.setattr(breakdown_extract, "_call_extraction_model", bad_call)

    with pytest.raises(BreakdownExtractError, match="state_id"):
        extract_mention_ledger("tartarus", 1, model="test-model")


def test_dry_run_zero_writes(monkeypatch, tmp_path):
    """With ``write=False`` the ledger is returned but nothing is written."""
    project_paths = _make_project(tmp_path)
    _patch_paths(monkeypatch, project_paths)
    requested_scene_ids: list[str] = []
    _patch_llm(monkeypatch, requested_scene_ids)

    ledger = extract_mention_ledger(
        "tartarus",
        1,
        model="test-model",
        write=False,
    )

    # Extraction still ran for every scene...
    assert len(ledger["scenes"]) == 3
    assert requested_scene_ids == ["EP001_SC001", "EP001_SC002", "EP001_SC003"]

    # ...but the breakdown directory was never created.
    breakdown_dir = project_paths.episode_breakdown_dir(1)
    assert not breakdown_dir.exists()


def test_extraction_failure_raises_without_empty_ledger_write(monkeypatch, tmp_path):
    """A failing model call raises an error naming the scene and writes no ledger."""
    project_paths = _make_project(tmp_path)
    _patch_paths(monkeypatch, project_paths)

    def exploding_call(model: str, system_prompt: str, user_prompt: str) -> str:
        # Simulate the model call failing on the very first scene.
        raise RuntimeError("transport down")

    monkeypatch.setattr(breakdown_extract, "_call_extraction_model", exploding_call)

    with pytest.raises(BreakdownExtractError, match="EP001_SC001"):
        extract_mention_ledger("tartarus", 1, model="test-model")

    ledger_file = project_paths.episode_breakdown_dir(1) / "mention_ledger.json"
    assert not ledger_file.exists()


def test_cli_dry_run_prints_counts_and_writes_nothing(monkeypatch, tmp_path, capsys):
    """The CLI's ``--dry-run`` prints a summary, exits 0 and writes nothing."""
    project_paths = _make_project(tmp_path)
    _patch_paths(monkeypatch, project_paths)
    requested_scene_ids: list[str] = []
    _patch_llm(monkeypatch, requested_scene_ids)

    # Imported inside the test, after the patches are installed.
    from recoil.pipeline.tools import breakdown_extract_cli

    argv = ["--project", "tartarus", "--episode", "1", "--dry-run"]
    exit_code = breakdown_extract_cli.main(argv)
    stdout = capsys.readouterr().out

    assert exit_code == 0
    # 3 scenes; 3 + 3 + 1 canned mentions across them.
    expected_fragments = (
        "DRY-RUN mention_ledger",
        "scenes=3",
        "mentions=7",
        "writes=0",
    )
    for fragment in expected_fragments:
        assert fragment in stdout
    assert not project_paths.episode_breakdown_dir(1).exists()


def _assert_required_shape(mention: dict) -> None:
    """Assert that *mention* has every field its kind requires.

    All mentions must carry truthy ``kind``, ``surface_text``, ``scene_id``,
    ``scene_hash`` and ``span_quote`` values; on top of that each kind has its
    own required fields (see the table below). A kind that is not in the table
    raises ``KeyError``.
    """
    fields_by_kind = {
        "character": ("character_id",),
        "location": ("location_id",),
        "sublocation": ("location_id", "sublocation"),
        "prop": ("prop_id",),
        "prop_state": ("prop_id", "state_id"),
        "wardrobe_change": ("character_id", "piece", "change"),
        "transient_state": ("character_id", "state_desc"),
        "identity_observation": ("character_id", "attribute", "observed_value"),
    }
    common_fields = ("kind", "surface_text", "scene_id", "scene_hash", "span_quote")

    # Common fields first, so a missing "kind" fails as an assertion before
    # the per-kind lookup is attempted.
    for field in common_fields:
        assert mention.get(field)

    for field in fields_by_kind[mention["kind"]]:
        assert mention.get(field)
