from __future__ import annotations

import asyncio
import dataclasses
import importlib.util
import json
from pathlib import Path

import pytest

from recoil.pipeline._lib import board_builder as bb
from recoil.pipeline._lib import derivation_manifest
from recoil.pipeline._lib.derivation_sha import shotset_hash
from recoil.pipeline._lib.plan_loader import CanonicalPlan, CanonicalShot
from recoil.pipeline.core.persistence import save_scene, scene_path
from recoil.pipeline.core.take import Beat, Scene
from recoil.pipeline.orchestrator import episode_runner as er
from recoil.pipeline.orchestrator.episode_runner import (
    BoardGateError,
    EpisodeRunner,
)


def _load_reroll_module():
    """Import ``api/routes/reroll.py`` by file path as a standalone module.

    The route file is loaded under the private name ``_board_gate_reroll_route``
    so tests can monkeypatch its module-level names (``_runner_for_reroll``,
    ``_preflight_board_gate``) on the returned object.

    Returns:
        The freshly executed module object.

    Raises:
        AssertionError: if no import spec / loader can be built for the path.
    """
    path = Path(__file__).resolve().parents[1] / "api" / "routes" / "reroll.py"
    spec = importlib.util.spec_from_file_location("_board_gate_reroll_route", path)
    # Validate the spec BEFORE using it: spec_from_file_location may return None,
    # and module_from_spec(None) would fail with an opaque AttributeError.
    assert spec is not None and spec.loader is not None, f"cannot load {path}"
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


@pytest.fixture(autouse=True)
def _project_root(tmp_path, monkeypatch):
    """Isolate every test in a throwaway projects root holding project ``fixture``.

    Points ``RECOIL_PROJECTS_ROOT`` at a tmp directory marked with
    ``.recoil-data-root``, removes any ambient ``RECOIL_BOARD_GATE`` so each test
    starts from the gate's default, and stubs the runner's project config to
    ``{}``. Returns the ``fixture`` project directory.
    """
    projects_root = tmp_path / "projects"
    projects_root.mkdir()
    (projects_root / ".recoil-data-root").touch()
    fixture_dir = projects_root / "fixture"
    fixture_dir.mkdir()

    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", str(projects_root))
    monkeypatch.delenv("RECOIL_BOARD_GATE", raising=False)
    monkeypatch.setattr(er, "load_project_config", lambda _project: {})
    return fixture_dir


def _scene(*beats: Beat) -> Scene:
    """Wrap ``beats`` in scene ``BATCH_004`` of the fixture project's ``ep_001``."""
    metadata = {"episode": "ep_001", "project": "fixture"}
    return Scene(scene_id="BATCH_004", beats=[*beats], scene_metadata=metadata)


def _beat(
    beat_id: str = "BATCH_004",
    *,
    modality: str = "r2v_multi",
    board: dict | None = None,
) -> Beat:
    """Build a minimal ungrouped five-take beat with an optional board cache."""
    # REC-231 Phase 5: production beats always carry beat_metadata["scene_id"]
    # (the batch id), and the freshness gate fails loudly without it, so the
    # fixture mirrors production by reusing the beat id.
    metadata = {"modality": modality, "scene_id": beat_id}
    return Beat(beat_id=beat_id, max_takes=5, beat_metadata=metadata, board=board)


def _proposed_board(artifact: str = "prep/ep_001/storyboards/board.png") -> dict:
    return {
        "status": "proposed",
        "artifact": artifact,
        "source_sha256": "abc123",
        "approved_by": None,
        "updated_at": "2026-06-11T00:00:00Z",
    }


def _approved_board(artifact: str) -> dict:
    """Return a board cache entry for ``artifact`` marked approved by ``JT``."""
    # Overriding existing keys keeps their original position in the dict.
    return {**_proposed_board(artifact), "status": "approved", "approved_by": "JT"}


def _grouped_beat(
    *,
    board: dict | None = None,
    shot_id: str = "EP001_SH10",
    include_hash: bool = True,
) -> Beat:
    """Build the grouped continuity beat ``BATCH_004`` covering a single shot.

    Args:
        board: optional board cache to attach to the beat.
        shot_id: id of the single covered shot. It is written into
            ``grouping["shot_ids"]`` and into the ``shot`` / ``batch_shots``
            payload, so both agree.
        include_hash: when False, ``grouping["shotset_hash"]`` is omitted, so
            consumers must derive the hash from ``shot_ids``.

    Returns:
        A five-take ``r2v_multi`` beat whose metadata carries the grouping.
    """
    shot = dataclasses.asdict(_canonical_shot(10))
    # Keep the payload consistent with the grouping. Without this, a non-default
    # shot_id would group one id while the shot payload still named EP001_SH10.
    shot["shot_id"] = shot_id
    grouping = {
        "strategy": "continuity",
        "ordinal": 4,
        "shot_ids": [shot_id],
        "source_pass_id": None,
    }
    if include_hash:
        grouping["shotset_hash"] = shotset_hash([shot_id])
    return Beat(
        beat_id="BATCH_004",
        max_takes=5,
        beat_metadata={
            "scene_id": "BATCH_004",
            "modality": "r2v_multi",
            "shot": shot,
            "batch_shots": [shot],
            "batch_summary": {"shared_characters": []},
            "grouping": grouping,
        },
        board=board,
    )


def _gate_source_sha(beat: Beat, *, version: int = 2) -> str:
    """Compute ``beat``'s ``source_sha256`` via board_builder's timing fingerprint.

    A primitive without ``timing_segments`` (or with an empty value) is
    fingerprinted as an empty segment list.
    """
    primitive = bb._primitive_from_beat("fixture", beat.beat_id, beat)
    return bb.compute_source_sha256(
        list(getattr(primitive, "timing_segments", None) or []),
        version=version,
    )


def _stamp_gate_record(
    beat: Beat,
    *,
    status: str = "approved",
    artifact: str,
    photoreal_artifact: str | None = None,
    needs_revalidation: bool | None = None,
) -> dict:
    """Stamp a board SSOT record for ``beat`` into the derivation manifest.

    The record is keyed by the grouping's ``shotset_hash``, falling back to
    hashing ``shot_ids`` when the grouping carries none. It is stamped for
    project ``fixture``, episode 1. ``needs_revalidation`` is written only
    when given. Returns the stamped record.
    """
    metadata = beat.beat_metadata or {}
    grouping = metadata.get("grouping") or {}
    key = grouping.get("shotset_hash") or shotset_hash(grouping["shot_ids"])
    record = dict(
        status=status,
        artifact=artifact,
        photoreal_artifact=photoreal_artifact,
        source_sha256=_gate_source_sha(beat, version=2),
        fingerprint_version=2,
        covered_shot_ids=[*(grouping.get("shot_ids") or [])],
        approved_by="JT",
        updated_at="2026-06-14T00:00:00Z",
    )
    if needs_revalidation is not None:
        record["needs_revalidation"] = needs_revalidation
    derivation_manifest.stamp_board("fixture", 1, key, record)
    return record


def _write_valid_png(path: Path) -> None:
    """Write a 256x256 RGB noise image to ``path``, creating parent directories.

    The closing assertion pins the file at 1 KiB or more. Noise compresses
    poorly, so the PNG clears that bound. Presumably the gate's artifact check
    rejects smaller files — confirm.
    """
    from PIL import Image

    path.parent.mkdir(exist_ok=True, parents=True)
    noise = Image.effect_noise((256, 256), 50)
    noise.convert("RGB").save(path)
    assert path.stat().st_size >= 1024


def _runner(monkeypatch):
    """Return ``(runner, attempts)`` for an ``ep_001`` runner with dispatch stubbed.

    The stub appends each beat id it receives to ``attempts`` instead of
    dispatching, so tests can assert exactly which beats got past the gate.
    """
    dispatched: list[str] = []

    async def _fake_dispatch(beat, *_args, **_kwargs):
        dispatched.append(beat.beat_id)

    runner = EpisodeRunner(project="fixture", plan={}, episode="ep_001")
    monkeypatch.setattr(runner, "_dispatch_one_beat", _fake_dispatch)
    return runner, dispatched


def _canonical_shot(index: int) -> CanonicalShot:
    """Return shot ``EP001_SH<index, zero-padded to 2>`` for the video pipeline.

    The shot is a 2.0 s, 9:16, ``seeddance-2.0`` shot of type ``MS`` at
    ``LOC_1``. It has no characters or dialogue, and no previs / quality /
    cinematography overrides.
    """
    shot_id = f"EP001_SH{index:02d}"
    return CanonicalShot(
        shot_id=shot_id,
        scene_index=1,
        sequence_id=None,
        # Pipeline / model selection.
        pipeline="video",
        previs_model=None,
        video_model="seeddance-2.0",
        # Staging.
        location_id="LOC_1",
        characters=[],
        shot_type="MS",
        duration_s=2.0,
        is_env_only=False,
        has_dialogue=False,
        aspect_ratio="9:16",
        # Optional overrides left unset.
        quality=None,
        cinematography=None,
        raw={},
    )


def _r2v_multi_plan() -> CanonicalPlan:
    """Return a three-shot (SH01..SH03) canonical plan for episode ``ep_001``."""
    shots = list(map(_canonical_shot, range(1, 4)))
    raw_plan = {
        "episode_id": "ep_001",
        "project": "fixture",
        "shots": [{"shot_id": s.shot_id} for s in shots],
    }
    return CanonicalPlan(
        episode_id="ep_001",
        project="fixture",
        shots=shots,
        source_path=Path("fixture_plan.json"),
        raw=raw_plan,
    )


def _batch_runner(plan: CanonicalPlan) -> EpisodeRunner:
    """Build an ``ep_001`` runner over ``plan.raw`` with no casting, concurrency 1."""
    runner_kwargs = {
        "project": "fixture",
        "plan": plan.raw,
        "casting": {},
        "episode": "ep_001",
        "concurrency": 1,
    }
    return EpisodeRunner(**runner_kwargs)


def _trap_run_scene(calls: list[str]):
    async def _run_scene(scene, **_kwargs):  # noqa: ANN001, ANN003
        calls.append(scene.scene_id)
        raise AssertionError("run_scene must not be called in derive_only mode")

    return _run_scene


def test_gate_off_by_default_dispatch_proceeds(monkeypatch):
    """With no env override and an empty config, an unboarded beat dispatches."""
    runner, attempts = _runner(monkeypatch)
    scene = _scene(_beat())

    asyncio.run(runner.run_scene(scene))

    assert attempts == ["BATCH_004"]


def test_env_force_on_blocks_unapproved_before_dispatch(monkeypatch):
    """``RECOIL_BOARD_GATE=1`` rejects a board-less beat before any dispatch."""
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    runner, attempts = _runner(monkeypatch)
    scene = _scene(_beat())

    with pytest.raises(BoardGateError) as excinfo:
        asyncio.run(runner.run_scene(scene))

    error = excinfo.value
    assert error.beat_id == "BATCH_004"
    assert error.reason == "no_board"
    assert attempts == []


def test_config_list_enables_per_episode(monkeypatch):
    """``board_gate_episodes`` gates only the episode numbers it lists."""
    monkeypatch.setattr(
        er, "load_project_config", lambda _project: {"board_gate_episodes": [1]}
    )
    runner, attempts = _runner(monkeypatch)

    # Episode 1 is listed: the unboarded beat is blocked, nothing dispatched.
    with pytest.raises(BoardGateError):
        asyncio.run(runner.run_scene(_scene(_beat())))
    assert attempts == []

    # Episode 2 is not listed: the same kind of beat goes straight through.
    runner.episode = "ep_002"
    asyncio.run(runner.run_scene(_scene(_beat("BATCH_005"))))
    assert attempts == ["BATCH_005"]


def test_env_force_off_wins_over_config(monkeypatch):
    """``RECOIL_BOARD_GATE=0`` disables the gate even for a configured episode."""
    monkeypatch.setattr(
        er, "load_project_config", lambda _project: {"board_gate_episodes": [1]}
    )
    monkeypatch.setenv("RECOIL_BOARD_GATE", "0")
    runner, attempts = _runner(monkeypatch)
    scene = _scene(_beat())

    asyncio.run(runner.run_scene(scene))

    assert attempts == ["BATCH_004"]


def test_approved_finished_board_passes(_project_root, monkeypatch):
    # asset-ssot SYNTHESIS board->r2v contract: a board-gated r2v shot needs an
    # approved board WITH a photoreal finish; the loose raw board alone no longer
    # ships. preferred_board_artifact returns the finish -> preflight gate passes.
    storyboards = Path("prep/ep_001/storyboards")
    artifact = storyboards / "EP001_CONT_004_v01.png"
    finish = storyboards / "EP001_CONT_004_v01_photoreal.png"
    for relative in (artifact, finish):
        _write_valid_png(_project_root / relative)
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    runner, attempts = _runner(monkeypatch)
    board = {**_approved_board(str(artifact)), "photoreal_artifact": str(finish)}

    asyncio.run(runner.run_scene(_scene(_beat(board=board))))

    assert attempts == ["BATCH_004"]


def test_approved_raw_only_board_blocks_at_gate(_project_root, monkeypatch):
    # Inverse of the contract (SYNTHESIS V3): an APPROVED board with only the raw
    # pencil artifact (no photoreal finish) is blocked at the preflight gate BEFORE
    # any spend -> BoardGateError(no_board).
    artifact = Path("prep/ep_001/storyboards") / "EP001_CONT_004_v01.png"
    _write_valid_png(_project_root / artifact)
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    runner, attempts = _runner(monkeypatch)
    raw_only_board = _approved_board(str(artifact))  # no "photoreal_artifact"
    scene = _scene(_beat(board=raw_only_board))

    with pytest.raises(BoardGateError) as excinfo:
        asyncio.run(runner.run_scene(scene))

    assert excinfo.value.reason == "no_board"
    assert attempts == []


@pytest.mark.parametrize(
    "record_kwargs",
    [
        {"status": "rejected"},
        {"needs_revalidation": True},
    ],
)
def test_preflight_board_gate_blocks_ssot_smuggled_cache_approval(
    _project_root,
    monkeypatch,
    record_kwargs,
):
    """An ``approved`` cache cannot override a rejected / stale SSOT record."""
    artifact = Path("prep/ep_001/storyboards") / "EP001_CONT_004_v01.png"
    _write_valid_png(_project_root / artifact)
    # The cache on the beat claims approval; the SSOT record says otherwise.
    beat = _grouped_beat(board=_approved_board(str(artifact)))
    _stamp_gate_record(beat, artifact=str(artifact), **record_kwargs)
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    runner, attempts = _runner(monkeypatch)
    scene = _scene(beat)

    with pytest.raises(BoardGateError) as excinfo:
        asyncio.run(runner.run_scene(scene))

    assert excinfo.value.reason == "board_ssot_not_approved"
    assert attempts == []


def test_preflight_board_gate_carries_approved_board_from_ssot(
    _project_root,
    monkeypatch,
):
    """An approved SSOT record lands on a cache-less beat, which then dispatches."""
    storyboards = Path("prep/ep_001/storyboards")
    artifact = storyboards / "EP001_CONT_004_v01.png"
    finish = storyboards / "EP001_CONT_004_v01_photoreal.png"
    for relative in (artifact, finish):
        _write_valid_png(_project_root / relative)
    beat = _grouped_beat(board=None)
    record = _stamp_gate_record(
        beat,
        artifact=str(artifact),
        photoreal_artifact=str(finish),
    )
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    runner, attempts = _runner(monkeypatch)

    asyncio.run(runner.run_scene(_scene(beat)))

    assert attempts == ["BATCH_004"]
    # After the run, beat.board equals the cache projection of the SSOT record.
    assert beat.board == bb.board_record_to_cache(record)


def test_preflight_board_gate_blocks_grouped_cache_without_ssot_record(
    _project_root,
    monkeypatch,
):
    """A grouped beat's approved cache is not enough without an SSOT record."""
    artifact = Path("prep/ep_001/storyboards") / "EP001_CONT_004_v01.png"
    _write_valid_png(_project_root / artifact)
    cached_only = _grouped_beat(board=_approved_board(str(artifact)))
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    runner, attempts = _runner(monkeypatch)
    scene = _scene(cached_only)

    with pytest.raises(BoardGateError) as excinfo:
        asyncio.run(runner.run_scene(scene))

    assert excinfo.value.reason == "board_ssot_not_approved"
    assert attempts == []


def test_proposed_board_blocks_with_reason(monkeypatch):
    """A ``proposed`` board blocks with reason ``board_proposed``."""
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    runner, attempts = _runner(monkeypatch)
    scene = _scene(_beat(board=_proposed_board()))

    with pytest.raises(BoardGateError) as excinfo:
        asyncio.run(runner.run_scene(scene))

    assert excinfo.value.reason == "board_proposed"
    assert attempts == []


def test_two_beat_scene_prescans_before_any_dispatch(_project_root, monkeypatch):
    """One unapproved beat blocks the scene before the approved beat dispatches."""
    storyboards = Path("prep/ep_001/storyboards")
    artifact = storyboards / "EP001_CONT_004_v01.png"
    finish = storyboards / "EP001_CONT_004_v01_photoreal.png"
    for relative in (artifact, finish):
        _write_valid_png(_project_root / relative)
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    runner, attempts = _runner(monkeypatch)
    finished_board = {
        **_approved_board(str(artifact)),
        "photoreal_artifact": str(finish),
    }
    approved = _beat("BATCH_004", board=finished_board)
    unapproved = _beat("BATCH_005")

    with pytest.raises(BoardGateError) as excinfo:
        asyncio.run(runner.run_scene(_scene(approved, unapproved)))

    assert excinfo.value.beat_id == "BATCH_005"
    # Even the approved first beat never dispatched: the scan precedes dispatch.
    assert attempts == []


def test_derive_only_runs_free_with_gate_enabled(monkeypatch):
    """derive_only writes derivations and never reaches run_scene, gate or not."""
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    plan = _r2v_multi_plan()
    runner = _batch_runner(plan)
    scene_calls: list[str] = []
    monkeypatch.setattr(runner, "run_scene", _trap_run_scene(scene_calls))

    outcome = asyncio.run(
        runner.run_episode_batches(plan, derive_only=True, dry_run=False)
    )

    assert outcome["derive_only"] is True
    assert outcome["written"]
    assert scene_calls == []


def test_gate_not_called_on_derive_only(monkeypatch):
    """derive_only never invokes the module-level preflight board gate."""
    gate_calls: list[dict] = []

    def _spy_gate(**kwargs):
        gate_calls.append(kwargs)

    monkeypatch.setattr(er, "_preflight_board_gate", _spy_gate)
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    plan = _r2v_multi_plan()
    runner = _batch_runner(plan)
    scene_calls: list[str] = []
    monkeypatch.setattr(runner, "run_scene", _trap_run_scene(scene_calls))

    outcome = asyncio.run(
        runner.run_episode_batches(plan, derive_only=True, dry_run=False)
    )

    assert outcome["derive_only"] is True
    assert gate_calls == []
    assert scene_calls == []


def test_gate_still_blocks_paid_run_unapproved(monkeypatch):
    """A paid (non-derive-only) batch run with no approved boards is blocked.

    Per-beat dispatch is stubbed, as in ``_runner``. If the gate ever regresses,
    the test then fails on ``pytest.raises`` / ``attempts`` instead of falling
    through to the runner's real dispatch path.
    """
    monkeypatch.setenv("RECOIL_BOARD_GATE", "1")
    plan = _r2v_multi_plan()
    runner = _batch_runner(plan)
    attempts: list[str] = []

    async def _record(beat, *_args, **_kwargs):
        attempts.append(beat.beat_id)

    monkeypatch.setattr(runner, "_dispatch_one_beat", _record)

    with pytest.raises(BoardGateError):
        asyncio.run(
            runner.run_episode_batches(plan, derive_only=False, dry_run=False)
        )
    # The gate must fire before ANY beat is dispatched.
    assert attempts == []


def test_reroll_route_maps_board_gate_to_409(monkeypatch):
    """A BoardGateError raised by run_scene becomes a 409 JSON response."""
    reroll_mod = _load_reroll_module()
    grouping = dict(
        strategy="continuity",
        ordinal=4,
        shot_ids=[],
        source_pass_id=None,
    )
    beat = Beat(
        beat_id="BATCH_004",
        max_takes=5,
        beat_metadata={
            "modality": "r2v_multi",
            "grouping": grouping,
            "scene_id": "BATCH_004",
        },
    )
    scene = Scene(
        scene_id="BATCH_004",
        beats=[beat],
        scene_metadata={"grouping": grouping},
    )
    save_scene(scene, scene_path("fixture", "ep_001", "BATCH_004"))

    class GateRaisingRunner:
        """Passes cost/prepare, then fails scene execution on the board gate."""

        def _estimate_take_cost(self, _beat):
            return 1.0

        def prepare_beat_for_reroll(self, _scene, beat, *, expected_version=None):
            return {"beat_id": beat.beat_id, "next_take_index": 0}

        async def run_scene(self, *_args, **_kwargs):
            raise BoardGateError("BATCH_004", "no_board")

    monkeypatch.setattr(
        reroll_mod,
        "_runner_for_reroll",
        lambda _project, _episode: GateRaisingRunner(),
    )
    expected_body = {
        "error": "board_gate_blocked",
        "message": "Storyboard board gate blocked beat BATCH_004: no_board",
        "beat_id": "BATCH_004",
    }

    response = reroll_mod.reroll(
        {"project": "fixture", "episode": 1, "batch_id": "EP001_CONT_004"}
    )

    assert response.status_code == 409
    assert json.loads(response.body) == expected_body


def test_reroll_route_gate_runs_before_prepare(monkeypatch):
    """A gate-blocked reroll must write nothing.

    The route's board-gate pre-scan has to run BEFORE prepare_beat_for_reroll
    (which clears stale primaries and persists), so neither runner method may
    be reached here.
    """
    reroll_mod = _load_reroll_module()
    grouping = dict(
        strategy="continuity",
        ordinal=4,
        shot_ids=[],
        source_pass_id=None,
    )
    beat = Beat(
        beat_id="BATCH_004",
        max_takes=5,
        beat_metadata={
            "modality": "r2v_multi",
            "grouping": grouping,
            "scene_id": "BATCH_004",
        },
    )
    scene = Scene(
        scene_id="BATCH_004",
        beats=[beat],
        scene_metadata={"grouping": grouping},
    )
    save_scene(scene, scene_path("fixture", "ep_001", "BATCH_004"))

    class MustNotRunRunner:
        """Any call past cost estimation means the gate ran too late."""

        def _estimate_take_cost(self, _beat):
            return 1.0

        def prepare_beat_for_reroll(self, _scene, _beat, *, expected_version=None):
            raise AssertionError(
                "prepare_beat_for_reroll must not run when the board gate blocks"
            )

        async def run_scene(self, *_args, **_kwargs):
            raise AssertionError("run_scene must not run when the board gate blocks")

    def _blocking_gate(*, project, episode, beats):
        raise BoardGateError(beats[0].beat_id, "no_board")

    monkeypatch.setattr(
        reroll_mod,
        "_runner_for_reroll",
        lambda _project, _episode: MustNotRunRunner(),
    )
    monkeypatch.setattr(reroll_mod, "_preflight_board_gate", _blocking_gate)

    response = reroll_mod.reroll(
        {"project": "fixture", "episode": 1, "batch_id": "EP001_CONT_004"}
    )
    assert response.status_code == 409
