from __future__ import annotations

import json
from pathlib import Path

from PIL import Image
import pytest

from recoil.core.paths import ProjectPaths
from recoil.pipeline._lib import board_builder as bb
from recoil.pipeline._lib import story_gate as sg
from recoil.pipeline.core.persistence import (
    SceneVersionConflictError,
    load_manifest,
    load_scene,
    save_scene,
    scene_path,
)
from recoil.pipeline.core.receipts import GenerationReceipt
from recoil.pipeline.core.registry import MODALITY_STORYBOARD, RunResult
from recoil.pipeline.core.scene_version_store import SceneVersionStore
from recoil.pipeline.core.take import Beat, Scene


def _shot(n: int) -> dict:
    return {
        "shot_id": f"EP001_SH{n:02d}",
        "scene_index": 1,
        "duration_s": 1.0,
        "intent": f"Beat {n} action.",
        "asset_data": {"characters": [], "location_id": None},
        "spatial_data": {},
    }


@pytest.fixture()
def project_paths(tmp_path, monkeypatch):
    """Provide an empty ``fixture_project`` root under *tmp_path*.

    ``ProjectPaths.for_project`` is monkeypatched so that, for the duration of
    the test, every lookup (whatever project name is passed) returns this same
    ``ProjectPaths`` instance.
    """
    paths = ProjectPaths(project_root=tmp_path / "fixture_project")
    paths.project_root.mkdir(parents=True)

    def _always_fixture(cls, project=None):
        # Ignore the requested project: every lookup resolves to the tmp root.
        return paths

    monkeypatch.setattr(ProjectPaths, "for_project", classmethod(_always_fixture))
    return paths


def _write_batch_scene(project_paths: ProjectPaths) -> Path:
    """Save a one-beat ``BATCH_004`` scene for ``ep_001`` and return its path.

    The single beat is an ``r2v_multi`` batch of shots 10 and 11, with shot 10
    as its primary ``shot``. *project_paths* is not read directly here;
    presumably requiring it ensures the fixture's ``for_project`` patch is in
    place before ``scene_path`` resolves. TODO confirm against ``scene_path``.
    """
    batch_shots = [_shot(number) for number in (10, 11)]
    beat_metadata = {
        "scene_id": "BATCH_004",
        "modality": "r2v_multi",
        "shot": batch_shots[0],
        "batch_shots": batch_shots,
        "batch_summary": {"shared_characters": []},
    }
    batch_beat = Beat(beat_id="BATCH_004", beat_metadata=beat_metadata)
    scene = Scene(
        scene_id="BATCH_004",
        beats=[batch_beat],
        scene_metadata={"episode": "ep_001", "project": "fixture_project"},
    )
    target = scene_path("fixture_project", "ep_001", "BATCH_004")
    save_scene(scene, target)
    return target


def _settings_passthrough(segments, **_kwargs):
    return [dict(seg, setting=f"Setting {i}") for i, seg in enumerate(segments, start=1)]


def _receipt(success: bool = True) -> GenerationReceipt:
    """Build a storyboard ``GenerationReceipt`` for shot ``EP001_CONT_004``.

    On success the embedded ``RunResult`` points at ``/tmp/board.png`` and has
    no error. Otherwise it has no output path and the error ``"boom"``.
    """
    if success:
        output_path, error = "/tmp/board.png", None
    else:
        output_path, error = None, "boom"

    run = RunResult(
        id="run_test",
        modality=MODALITY_STORYBOARD,
        output_path=output_path,
        metadata={},
        success=success,
        error=error,
    )
    return GenerationReceipt(
        receipt_id="rcpt_test",
        modality=MODALITY_STORYBOARD,
        caller_id="board_builder",
        project="fixture_project",
        episode=1,
        shot_id="EP001_CONT_004",
        timestamp_utc="2026-06-11T00:00:00Z",
        run_result=run,
    )


def _dispatch_spy(prompts: list[str]):
    """Return a fake ``dispatch`` that records every prompt into *prompts*.

    Each call does the following:

    - Appends ``payload["prompt"]`` to *prompts*.
    - Writes a 40x40 solid-colour PNG named ``<filename_stem>.png`` under
      ``payload["save_dir"]``, creating the directory if needed.
    - Writes a ``<png>.json`` sidecar containing ``payload["sidecar_extra"]``.
    - Returns a successful receipt.

    ``modality`` and ``context`` are accepted but ignored.
    """

    def fake_dispatch(modality, payload, *, context):
        prompts.append(payload["prompt"])

        out_dir = Path(payload["save_dir"])
        out_dir.mkdir(parents=True, exist_ok=True)

        png_path = out_dir / f"{payload['filename_stem']}.png"
        Image.new("RGB", (40, 40), color=(32, 96, 160)).save(png_path)

        sidecar_path = png_path.with_name(png_path.name + ".json")
        sidecar_path.write_text(json.dumps(payload["sidecar_extra"]), encoding="utf-8")
        return _receipt(True)

    return fake_dispatch


def _forced_checks(*, failed: bool = False, severity: str = "HARD") -> dict:
    checks = {}
    for name in (
        "depicts_beat",
        "spatially_possible",
        "eyeline_consistent",
        "object_of_gaze_in_frame_and_front",
        "causal_setup_present",
    ):
        checks[name] = {
            "passed": not failed if name == "spatially_possible" else True,
            "severity": severity if name == "spatially_possible" else "SOFT",
            "confidence": 0.9,
            "reason": "needs repair" if failed and name == "spatially_possible" else "ok",
        }
    return checks


def _verdict(
    *,
    route: str,
    failed: bool = False,
    severity: str = "HARD",
    injectable: bool = True,
    fix_hint: str = "Put the pod in front of Jade.",
) -> dict:
    """Build a canned two-panel story-gate image verdict.

    Panel 1 carries the forced checks from ``_forced_checks(failed, severity)``.
    When *failed*, it also carries *fix_hint* and *injectable*; otherwise the
    hint is ``None`` and it is not injectable. Panel 2 and the single 1->2
    transition always pass. *route* is used as both ``routing.class`` and
    ``routing.evidence``.
    """
    if failed:
        panel_one_hint, panel_one_injectable = fix_hint, injectable
    else:
        panel_one_hint, panel_one_injectable = None, False

    panel_one = {
        "index": 1,
        "description": "Panel 1 visible.",
        "forced_checks": _forced_checks(failed=failed, severity=severity),
        "fix_hint": panel_one_hint,
        "fix_hint_injectable": panel_one_injectable,
    }
    panel_two = {
        "index": 2,
        "description": "Panel 2 visible.",
        "forced_checks": _forced_checks(),
        "fix_hint": None,
        "fix_hint_injectable": False,
    }
    causal_ok = {
        "passed": True,
        "severity": "HARD",
        "confidence": 0.9,
        "reason": "ok",
    }
    transition = {
        "from": 1,
        "to": 2,
        "forced_checks": {"causal_setup_present": causal_ok},
        "fix_hint": None,
        "fix_hint_injectable": False,
    }
    return {
        "text_stageability": None,
        "panels": [panel_one, panel_two],
        "transitions": [transition],
        "routing": {"class": route, "confidence": 0.9, "evidence": route},
    }


def _install_common(project_paths, monkeypatch, image_verdicts: list[dict | Exception]) -> list[str]:
    """Wire up the standard fakes for an auto-reroll test and return the prompt log.

    This does the following:

    - Writes the ``BATCH_004`` scene.
    - Sets ``RECOIL_STORY_GATE=shadow``.
    - Stubs ``derive_settings`` and ``dispatch`` in board_builder.
    - Replaces the story-gate judge.

    The fake judge answers every ``TEXT-STAGEABILITY`` prompt with a stageable
    result. Each image-judge call consumes the next entry of *image_verdicts*
    (mutating the list): an ``Exception`` entry is raised and a dict is returned.

    Returns:
        The list into which every dispatched prompt is appended.

    Raises (from the fake judge):
        AssertionError: if the image judge is called more times than there are
            scripted verdicts. This is reported explicitly rather than as an
            opaque ``IndexError`` from ``pop``.
    """
    _write_batch_scene(project_paths)
    prompts: list[str] = []
    monkeypatch.setenv("RECOIL_STORY_GATE", "shadow")
    monkeypatch.setattr(bb, "derive_settings", _settings_passthrough)
    monkeypatch.setattr(bb, "dispatch", _dispatch_spy(prompts))

    def judge(prompt, images, *, model=None, max_attempts=3):
        if "TEXT-STAGEABILITY" in prompt:
            return {"stageable": True, "findings": []}
        if not image_verdicts:
            raise AssertionError(
                "story-gate image judge called more times than scripted verdicts"
            )
        verdict = image_verdicts.pop(0)
        if isinstance(verdict, Exception):
            raise verdict
        return verdict

    monkeypatch.setattr(sg, "_judge_call", judge)
    return prompts


def _downstream(manifest: dict, version: int) -> str:
    return next(
        entry for entry in manifest["versions"] if entry["version"] == version
    )["downstream"]


def test_rerollable_hard_injects_fix_note_and_proposes_clean_second_attempt(
    project_paths,
    monkeypatch,
):
    """A rerollable HARD failure feeds its fix hint into attempt 2, which passes.

    The first image verdict fails ``spatially_possible`` on panel 1 with an
    injectable fix hint; the second verdict routes ``ok``. The test expects:

    - two attempts with no stop reason;
    - the panel-1 fix note in the second dispatched prompt;
    - a ``proposed`` v02 board stamped with the current fingerprint version.
    """
    prompts = _install_common(
        project_paths,
        monkeypatch,
        [
            _verdict(route="board_problem", failed=True, injectable=True),
            _verdict(route="ok"),
        ],
    )

    result = bb.build_with_auto_reroll(
        "fixture_project",
        1,
        "EP001_CONT_004",
        step_runner=object(),
    )

    board = load_scene(scene_path("fixture_project", "ep_001", "BATCH_004")).beats[0].board
    assert result["success"] is True
    assert result["attempts"] == 2
    assert result["stopped_reason"] is None
    assert result["fingerprint_version"] == bb.BOARD_FINGERPRINT_VERSION
    # One dispatch per attempt. Check the count before indexing so a missing
    # reroll fails here with the prompt log, not as an IndexError below.
    assert len(prompts) == 2, prompts
    assert "Panel 1: Put the pod in front of Jade." in prompts[1]
    assert board["status"] == "proposed"
    assert board["artifact"].endswith("EP001_CONT_004_v02.png")
    assert board["fingerprint_version"] == bb.BOARD_FINGERPRINT_VERSION
    assert board["reroll_attempts"] == 2


def test_auto_reroll_success_marks_active_version_derived(project_paths, monkeypatch):
    """A successful auto-reroll flips the active version's ``downstream`` to derived.

    Version 2 is written as a candidate and conformed, which leaves it
    ``not_derived``. After a single passing build, the manifest must report it
    as ``derived``.
    """
    scene_file = _write_batch_scene(project_paths)
    active_scene = load_scene(scene_file)
    version_store = SceneVersionStore("fixture_project", "ep_001")
    version_store.write_scene_candidate("BATCH_004", active_scene)
    version_store.conform("BATCH_004", 2)

    def v2_downstream():
        return _downstream(load_manifest("fixture_project", "ep_001", "BATCH_004"), 2)

    assert v2_downstream() == "not_derived"

    def always_ok_judge(prompt, images, *, model=None, max_attempts=3):
        # Text-stageability passes; every image verdict routes "ok".
        if "TEXT-STAGEABILITY" not in prompt:
            return _verdict(route="ok")
        return {"stageable": True, "findings": []}

    monkeypatch.setenv("RECOIL_STORY_GATE", "shadow")
    monkeypatch.setattr(bb, "derive_settings", _settings_passthrough)
    monkeypatch.setattr(bb, "dispatch", _dispatch_spy([]))
    monkeypatch.setattr(sg, "_judge_call", always_ok_judge)

    outcome = bb.build_with_auto_reroll(
        "fixture_project",
        1,
        "EP001_CONT_004",
        step_runner=object(),
    )

    assert outcome["success"] is True
    assert v2_downstream() == "derived"


def test_auto_reroll_active_version_move_before_second_attempt_raises_before_dispatch(
    project_paths,
    monkeypatch,
):
    """Moving the active scene version mid-loop aborts the reroll before attempt 2.

    The first image judge call conforms ``BATCH_004`` to version 2, which moves
    the active version, and then returns a rerollable HARD failure. The build
    must raise ``SceneVersionConflictError`` without dispatching a second
    prompt and without attaching a board to the active scene.
    """
    scene_file = _write_batch_scene(project_paths)
    store = SceneVersionStore("fixture_project", "ep_001")
    store.write_scene_candidate("BATCH_004", load_scene(scene_file))
    prompts: list[str] = []
    monkeypatch.setenv("RECOIL_STORY_GATE", "shadow")
    monkeypatch.setattr(bb, "derive_settings", _settings_passthrough)
    monkeypatch.setattr(bb, "dispatch", _dispatch_spy(prompts))
    # Counts image (non text-stageability) judge calls.
    image_judges = 0

    def judge(prompt, images, *, model=None, max_attempts=3):
        nonlocal image_judges
        if "TEXT-STAGEABILITY" in prompt:
            return {"stageable": True, "findings": []}
        image_judges += 1
        if image_judges == 1:
            # Side effect: switch the active version while attempt 1 is being
            # judged, then ask for a reroll.
            store.conform("BATCH_004", 2)
            return _verdict(route="board_problem", failed=True, injectable=True)
        return _verdict(route="ok")

    monkeypatch.setattr(sg, "_judge_call", judge)

    # REC-231: the second paid attempt must keep the version captured at loop start.
    with pytest.raises(SceneVersionConflictError):
        bb.build_with_auto_reroll(
            "fixture_project",
            1,
            "EP001_CONT_004",
            step_runner=object(),
        )

    # Only attempt 1 was dispatched, and nothing was written to the active scene.
    assert len(prompts) == 1
    assert bb.load_scene_active("fixture_project", "ep_001", "BATCH_004").beats[0].board is None


def test_auto_reroll_active_version_move_before_finalize_raises_without_write(
    project_paths,
    monkeypatch,
):
    """Moving the active scene version before finalize aborts without writing.

    The first (and only expected) image judge call conforms ``BATCH_004`` to
    version 2 and then returns a passing verdict. Finalizing that attempt must
    raise ``SceneVersionConflictError`` rather than attach the board to the
    now-active version. Any second image judge call fails the test.
    """
    scene_file = _write_batch_scene(project_paths)
    store = SceneVersionStore("fixture_project", "ep_001")
    store.write_scene_candidate("BATCH_004", load_scene(scene_file))
    prompts: list[str] = []
    monkeypatch.setenv("RECOIL_STORY_GATE", "shadow")
    monkeypatch.setattr(bb, "derive_settings", _settings_passthrough)
    monkeypatch.setattr(bb, "dispatch", _dispatch_spy(prompts))
    # Counts image (non text-stageability) judge calls.
    image_judges = 0

    def judge(prompt, images, *, model=None, max_attempts=3):
        nonlocal image_judges
        if "TEXT-STAGEABILITY" in prompt:
            return {"stageable": True, "findings": []}
        image_judges += 1
        if image_judges == 1:
            # Side effect: switch the active version, then pass the board so
            # the loop proceeds straight to finalize.
            store.conform("BATCH_004", 2)
            return _verdict(route="ok")
        raise AssertionError("finalize conflict should stop after first attempt")

    monkeypatch.setattr(sg, "_judge_call", judge)

    # REC-231: finalize must not attach an attempt generated for v1 to newly-active v2.
    with pytest.raises(SceneVersionConflictError):
        bb.build_with_auto_reroll(
            "fixture_project",
            1,
            "EP001_CONT_004",
            step_runner=object(),
        )

    # Exactly one dispatch happened and the active scene still has no board.
    assert len(prompts) == 1
    assert bb.load_scene_active("fixture_project", "ep_001", "BATCH_004").beats[0].board is None


def test_non_injectable_hard_failure_does_not_reroll(project_paths, monkeypatch):
    """A HARD failure whose fix hint is not injectable stops after one attempt.

    The loop must not reroll. It reports ``non_injectable_hard_fail`` and
    records the board as ``rejected``.
    """
    only_verdict = _verdict(route="board_problem", failed=True, injectable=False)
    _install_common(project_paths, monkeypatch, [only_verdict])

    outcome = bb.build_with_auto_reroll(
        "fixture_project",
        1,
        "EP001_CONT_004",
        step_runner=object(),
    )

    saved_scene = load_scene(scene_path("fixture_project", "ep_001", "BATCH_004"))
    saved_board = saved_scene.beats[0].board
    assert outcome["attempts"] == 1
    assert outcome["stopped_reason"] == "non_injectable_hard_fail"
    assert saved_board["status"] == "rejected"


def test_script_problem_does_not_reroll(project_paths, monkeypatch):
    """A verdict routed to ``script_problem`` is not rerollable.

    A single attempt is made and the stop reason names the route.
    """
    script_verdict = _verdict(route="script_problem", failed=True, injectable=False)
    _install_common(project_paths, monkeypatch, [script_verdict])

    outcome = bb.build_with_auto_reroll(
        "fixture_project",
        1,
        "EP001_CONT_004",
        step_runner=object(),
    )

    assert outcome["attempts"] == 1
    assert outcome["stopped_reason"] == "non_rerollable_route:script_problem"


def test_judge_unavailable_on_second_attempt_proposes_first_attempt(
    project_paths,
    monkeypatch,
):
    """If the judge is unavailable when scoring the reroll, attempt 1 is proposed.

    Attempt 1 fails with an injectable hint, so a reroll happens. Judging
    attempt 2 raises ``StoryGateJudgeUnavailable``. The result and the saved
    board must both point at the v01 artifact with status ``proposed``.
    """
    first_artifact = "EP001_CONT_004_v01.png"
    scripted = [
        _verdict(route="board_problem", failed=True, injectable=True),
        sg.StoryGateJudgeUnavailable("rate limited"),
    ]
    _install_common(project_paths, monkeypatch, scripted)

    outcome = bb.build_with_auto_reroll(
        "fixture_project",
        1,
        "EP001_CONT_004",
        step_runner=object(),
    )

    saved_board = load_scene(scene_path("fixture_project", "ep_001", "BATCH_004")).beats[0].board
    assert outcome["attempts"] == 2
    assert outcome["stopped_reason"] == "judge_unavailable"
    assert outcome["artifact"].endswith(first_artifact)
    assert saved_board["status"] == "proposed"
    assert saved_board["artifact"].endswith(first_artifact)


def test_attempt_cap_rejected_and_lineage_recorded(project_paths, monkeypatch):
    """Hitting ``max_attempts`` rejects the last board and records both attempts.

    Both scripted verdicts fail with injectable hints, so the loop runs until
    the cap of 2. The per-attempt lineage must appear identically in the result
    and in the v02 sidecar. The sidecar and the saved board must both carry the
    current fingerprint version.
    """
    _install_common(
        project_paths,
        monkeypatch,
        [
            _verdict(route="board_problem", failed=True, injectable=True, fix_hint="Fix A."),
            _verdict(route="board_problem", failed=True, injectable=True, fix_hint="Fix B."),
        ],
    )

    outcome = bb.build_with_auto_reroll(
        "fixture_project",
        1,
        "EP001_CONT_004",
        step_runner=object(),
        max_attempts=2,
    )

    saved_board = load_scene(scene_path("fixture_project", "ep_001", "BATCH_004")).beats[0].board
    storyboards = project_paths.episode_storyboards_dir(1)
    sidecar = json.loads((storyboards / "EP001_CONT_004_v02.png.json").read_text(encoding="utf-8"))

    expected_lineage = [
        {
            "attempt": 1,
            "artifact": "prep/ep_001/storyboards/EP001_CONT_004_v01.png",
            "route": "board_problem",
            "hard_fails": 1,
            "soft_fails": 0,
        },
        {
            "attempt": 2,
            "artifact": "prep/ep_001/storyboards/EP001_CONT_004_v02.png",
            "route": "board_problem",
            "hard_fails": 1,
            "soft_fails": 0,
        },
    ]

    assert outcome["attempts"] == 2
    assert outcome["stopped_reason"] == "attempt_cap_reached"
    assert outcome["reroll_lineage"] == expected_lineage
    assert sidecar["reroll_lineage"] == outcome["reroll_lineage"]
    assert sidecar["fingerprint_version"] == bb.BOARD_FINGERPRINT_VERSION
    assert saved_board["status"] == "rejected"
    assert saved_board["artifact"].endswith("EP001_CONT_004_v02.png")
    assert saved_board["fingerprint_version"] == bb.BOARD_FINGERPRINT_VERSION
    assert saved_board["reroll_attempts"] == 2


def test_reroll_select_missing_fingerprint_version_raises(project_paths):
    """Finalizing a selected attempt whose result has no fingerprint_version fails.

    ``_finalize_reroll_board`` must raise ``BoardBuilderError`` mentioning the
    missing ``fingerprint_version``.
    """
    _write_batch_scene(project_paths)
    attempt_result = {
        "artifact": "prep/ep_001/storyboards/EP001_CONT_004_v01.png",
        "source_sha256": "abc123",
    }
    chosen_attempt = {"attempt": 1, "result": attempt_result}
    lineage = [chosen_attempt]

    with pytest.raises(bb.BoardBuilderError, match="missing fingerprint_version"):
        bb._finalize_reroll_board(
            "fixture_project",
            1,
            "EP001_CONT_004",
            chosen_attempt,
            lineage,
            status="proposed",
            expected_version=1,
        )
