from __future__ import annotations

import dataclasses
import json
from pathlib import Path
from types import SimpleNamespace

import pytest

from recoil.pipeline._lib import dispatch_payload as dp
from recoil.pipeline._lib.dispatch_payload import AuthorPromptResult, build_dispatch_payload
from recoil.pipeline._lib.plan_loader import CanonicalShot
from recoil.pipeline.core.take import Beat
from recoil.pipeline.orchestrator.episode_runner import EpisodeRunner


# Default prompt used by ``_direct_ctx`` for the filter-safety tests. The
# shadow-mode test in this module expects this text to yield exactly one
# "warn" finding and zero "info" findings.
FILTER_SAFETY_TRIGGER_PROMPT = (
    "She draws her bloodied katana mid-strike, the lethal warrior in her "
    "torn harness confronting the guards."
)


def _shot(shot_id: str = "EP001_SH01") -> CanonicalShot:
    """Return a ``CanonicalShot`` fixture that differs only by ``shot_id``.

    All other fields are fixed: a 2-second "MS" video shot in sequence
    ``SEQ001`` with no characters, no location and no dialogue.
    """
    # Built inside the function so every call gets its own list/dict objects.
    spec = {
        "shot_id": shot_id,
        "scene_index": 1,
        "sequence_id": "SEQ001",
        "pipeline": "video",
        "previs_model": None,
        "video_model": "seeddance-2.0",
        "location_id": None,
        "characters": [],
        "shot_type": "MS",
        "duration_s": 2.0,
        "is_env_only": False,
        "has_dialogue": False,
        "aspect_ratio": "9:16",
        "raw": {},
    }
    return CanonicalShot(**spec)


def _payload(
    *,
    monkeypatch: pytest.MonkeyPatch,
    refs: list[str],
    manifest: dict[str, int],
    board_ref_path: str | None = None,
    segment_count: int = 2,
) -> dict:
    """Build an ``r2v_multi`` dispatch payload with ref collection and authoring stubbed.

    ``_collect_reference_images`` is patched to return copies of ``refs`` and
    ``manifest``; ``_build_author_aware_prompt`` is patched to return a fixed
    two-line prompt while recording the ``ref_manifest`` it was handed.

    Returns the payload from ``build_dispatch_payload`` with one extra key,
    ``_captured_ref_manifest``, holding the manifest seen by the last author call.
    """
    seen_manifests: list[dict[str, int]] = []

    def _stub_collect(*args, **kwargs):
        # Fresh copies on every call.
        return list(refs), dict(manifest)

    def _stub_author(ctx, **kwargs):
        seen_manifests.append(dict(kwargs["ref_manifest"]))
        return AuthorPromptResult(
            prompt="CAST: Jade, Wren\nShot body.",
            modality=ctx.modality,
            strategy="shot_spec",
            payload_refs={},
        )

    patches = (
        ("_collect_reference_images", _stub_collect),
        ("_build_author_aware_prompt", _stub_author),
    )
    for attr_name, replacement in patches:
        monkeypatch.setattr(dp, attr_name, replacement)

    # Shot ids run EP001_SH01 .. EP001_SH<segment_count>.
    batch = [_shot(f"EP001_SH{offset + 1:02d}") for offset in range(segment_count)]
    result = build_dispatch_payload(
        shot=batch[0],
        project="_board_ref_attach",
        modality="r2v_multi",
        batch_shots=batch,
        episode="ep_001",
        board_ref_path=board_ref_path,
    )
    result["_captured_ref_manifest"] = seen_manifests[-1]
    return result


def _direct_ctx(
    *,
    project: str = "_filter_safety_test",
    prompt: str = FILTER_SAFETY_TRIGGER_PROMPT,
) -> dp.PayloadContext:
    """Return a ``video_i2v`` ``PayloadContext`` for shot ``EP001_SH99``.

    Only ``project`` and ``prompt`` are configurable; the remaining fields
    are fixed (seeddance-2.0, 2 seconds, 9:16, audio generation off).
    """
    ctx_fields = {
        "project": project,
        "modality": "video_i2v",
        "shot_id": "EP001_SH99",
        "prompt": prompt,
        "model_id": "seeddance-2.0",
        "duration_s": 2.0,
        "aspect_ratio": "9:16",
        "generate_audio": False,
    }
    return dp.PayloadContext(**ctx_fields)


def test_board_absent_preserves_refs_and_prompt(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no board path, refs and prompt are unchanged and no ``ref_manifest`` key appears."""
    result = _payload(
        monkeypatch=monkeypatch,
        refs=["/refs/hero.png", "/refs/location.png"],
        manifest={"identity_1": 1, "scene_1": 2},
    )

    attached_refs = result["reference_images"]
    assert attached_refs == ["/refs/hero.png", "/refs/location.png"]

    prompt_lines = result["prompt"].splitlines()
    assert prompt_lines == ["CAST: Jade, Wren", "Shot body."]

    assert "ref_manifest" not in result


def test_filter_safety_shadow_sets_json_roundtrippable_payload_key(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Shadow mode leaves the prompt untouched, adds a JSON-serialisable
    ``filter_safety`` summary to the payload, and appends one row to the lint log."""
    lint_log = tmp_path / "filter_safety_lint.jsonl"
    monkeypatch.setenv("RECOIL_FILTER_SAFETY", "shadow")
    monkeypatch.setattr(dp, "_FILTER_SAFETY_LOG_PATH", lint_log)

    result = dp.build_unified_payload(_direct_ctx())

    assert result["prompt"] == FILTER_SAFETY_TRIGGER_PROMPT
    summary = result["filter_safety"]
    assert summary["warn"] == 1
    assert summary["info"] == 0
    # Raises if anything in the payload is not JSON-serialisable.
    json.dumps(result)

    logged = [json.loads(raw) for raw in lint_log.read_text().splitlines()]
    assert len(logged) == 1
    entry = logged[0]
    expected_fields = {
        "project": "_filter_safety_test",
        "shot_id": "EP001_SH99",
        "model": "seeddance-2.0",
        "mode": "shadow",
        "warn_count": 1,
        "info_count": 0,
    }
    for field_name, wanted in expected_fields.items():
        assert entry[field_name] == wanted
    assert entry["findings"]


def test_filter_safety_off_sets_no_payload_key_or_log(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With ``RECOIL_FILTER_SAFETY=off`` the payload has no ``filter_safety``
    key and the lint log file is never created."""
    lint_log = tmp_path / "filter_safety_lint.jsonl"
    monkeypatch.setenv("RECOIL_FILTER_SAFETY", "off")
    monkeypatch.setattr(dp, "_FILTER_SAFETY_LOG_PATH", lint_log)

    result = dp.build_unified_payload(_direct_ctx())

    assert "filter_safety" not in result
    log_was_written = lint_log.exists()
    assert not log_was_written


def test_filter_safety_lint_error_does_not_break_payload_and_logs_envelope(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A crashing linter must not break payload construction: the prompt is
    returned unchanged, no ``filter_safety`` key is set, and a single error
    envelope row is written to the lint log."""
    lint_log = tmp_path / "filter_safety_lint.jsonl"
    monkeypatch.setenv("RECOIL_FILTER_SAFETY", "shadow")
    monkeypatch.setattr(dp, "_FILTER_SAFETY_LOG_PATH", lint_log)

    def _exploding_lint(_prompt: str):
        raise RuntimeError("lint crashed")

    monkeypatch.setattr(dp, "lint_prompt", _exploding_lint)

    result = dp.build_unified_payload(_direct_ctx())

    assert result["prompt"] == FILTER_SAFETY_TRIGGER_PROMPT
    assert "filter_safety" not in result

    logged = [json.loads(raw) for raw in lint_log.read_text().splitlines()]
    assert len(logged) == 1
    envelope = logged[0]
    expected_fields = {
        "lint_error": "lint crashed",
        "project": "_filter_safety_test",
        "shot_id": "EP001_SH99",
        "model": "seeddance-2.0",
        "mode": "unknown-or-shadow",
    }
    for field_name, wanted in expected_fields.items():
        assert envelope[field_name] == wanted
    # A timestamp must be present and non-empty.
    assert envelope["ts"]


def test_dispatch_receipt_payload_keys_include_filter_safety(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Dispatching a shadow-mode payload through a stub runner yields a
    receipt whose ``provenance["payload_keys"]`` lists ``filter_safety``."""
    from recoil.pipeline.core.dispatch import dispatch
    from recoil.pipeline.core.dispatch_context import DispatchContext
    from recoil.pipeline.core.registry import RunResult, register_runner

    class _StubRunner:
        modality = "filter_safety_stub"

        def run(self, payload: dict) -> RunResult:
            # Echo the payload's key names back via metadata.
            key_names = sorted(payload)
            return RunResult(
                id="filter_safety_stub_result",
                modality=self.modality,
                output_path="/tmp/filter-safety-stub.png",
                metadata={"payload_keys": key_names},
                success=True,
            )

    lint_log = tmp_path / "filter_safety_lint.jsonl"
    monkeypatch.setenv("RECOIL_FILTER_SAFETY", "shadow")
    monkeypatch.setattr(dp, "_FILTER_SAFETY_LOG_PATH", lint_log)
    # NOTE(review): no teardown for this registration is visible here, so the
    # stub may stay registered for later tests — confirm against the registry.
    register_runner(_StubRunner.modality, _StubRunner())
    built = dp.build_unified_payload(_direct_ctx())

    dispatch_ctx = DispatchContext(
        caller_id="filter_safety_test",
        step_runner=SimpleNamespace(),
        project="_filter_safety_test",
        episode=1,
        receipts_log_path="DISABLED",
    )
    receipt = dispatch(_StubRunner.modality, built, context=dispatch_ctx)

    recorded_keys = receipt.provenance["payload_keys"]
    assert "filter_safety" in recorded_keys


def test_approved_board_rides_last_and_inserts_board_line(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A supplied board is appended after the existing refs, registered as
    ``board_1`` in the manifest, and announced on the prompt's second line."""
    board_file = tmp_path / "EP001_CONT_004_v01.png"
    board_file.write_bytes(b"board")
    board_str = str(board_file)

    result = _payload(
        monkeypatch=monkeypatch,
        refs=["/refs/hero.png", "/refs/location.png"],
        manifest={"identity_1": 1, "scene_1": 2},
        board_ref_path=board_str,
        segment_count=3,
    )

    expected_refs = ["/refs/hero.png", "/refs/location.png", board_str]
    assert result["reference_images"] == expected_refs

    # The board takes slot 3 in both the payload manifest and the one the
    # author stub received.
    assert result["ref_manifest"]["board_1"] == 3
    assert result["_captured_ref_manifest"]["board_1"] == 3

    prompt_lines = result["prompt"].splitlines()
    expected_board_line = (
        "The attached storyboard @Image3 defines the framing and composition "
        "of the 3 shots, panels 1-3 in order — match each shot to its panel."
    )
    assert prompt_lines[0] == "CAST: Jade, Wren"
    assert prompt_lines[1] == expected_board_line
    assert prompt_lines[2] == "Shot body."


def test_board_survives_nine_ref_pressure(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When nine refs are already collected, the board still lands in the
    payload: the first eight refs are kept, the board takes slot 9, and the
    ``scene_1`` manifest entry is dropped."""
    board_file = tmp_path / "EP001_CONT_004_v01.png"
    board_file.write_bytes(b"board")
    board_str = str(board_file)
    nine_refs = [f"/refs/ref_{idx + 1}.png" for idx in range(9)]

    result = _payload(
        monkeypatch=monkeypatch,
        refs=nine_refs,
        manifest={"identity_1": 1, "scene_1": 9},
        board_ref_path=board_str,
        segment_count=2,
    )

    attached = result["reference_images"]
    assert len(attached) <= 9
    assert attached == [*nine_refs[:8], board_str]

    out_manifest = result["ref_manifest"]
    assert out_manifest["board_1"] == 9
    assert "scene_1" not in out_manifest

    board_line = result["prompt"].splitlines()[1]
    assert "@Image9" in board_line


def test_unapproved_board_not_threaded_through_episode_runner(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A board that was proposed but never approved must reach the payload
    builder as ``board_ref_path=None``."""
    project_dir = tmp_path / "_board_ref_attach"
    project_dir.mkdir()
    # REC-231 Phase 5: _build_workflow_for_beat now reads the scene-version manifest
    # (load_manifest) to stamp scene_version provenance, so the data-root sentinel must
    # be present (production always has it; this fixture previously skipped it).
    data_root_sentinel = tmp_path / ".recoil-data-root"
    data_root_sentinel.touch()
    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", str(tmp_path))

    builder_kwargs: dict[str, object] = {}

    def _recording_builder(**kwargs):
        # Remember what the runner passed in, then return a minimal payload.
        builder_kwargs.update(kwargs)
        return {
            "shot_id": kwargs["shot"].shot_id,
            "prompt": "CAST: Jade\nShot body.",
            "model": "seeddance-2.0",
            "duration": 2,
            "aspect_ratio": "9:16",
            "generate_audio": False,
        }

    monkeypatch.setattr(dp, "build_dispatch_payload", _recording_builder)

    runner = EpisodeRunner(
        project="_board_ref_attach",
        plan={"sequences": {}},
        casting={},
        episode="ep_001",
    )
    shot = _shot()
    metadata = {
        "modality": "r2v_multi",
        "shot": dataclasses.asdict(shot),
        "batch_shots": [dataclasses.asdict(shot)],
        "scene_id": "BATCH_004",
    }
    beat = Beat(beat_id="EP001_CONT_004", beat_metadata=metadata)
    # Proposed only — approve_board() is deliberately not called.
    beat.set_board_proposed(
        "prep/ep_001/storyboards/EP001_CONT_004_v01.png",
        "a" * 64,
    )

    runner._build_workflow_for_beat(beat, take_index=0, beat_index=0)

    assert builder_kwargs["board_ref_path"] is None


def test_audit_dry_run_payload_carries_board_ref(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """The dry-run/audit payload feeds revalidate_succeeded_fingerprints — an
    approved board must land in its reference_images exactly like the live
    path, or approve-then-rerun never demotes succeeded takes."""

    def _fake_collect(*args, **kwargs):
        # One identity ref, so the board's position is unambiguous.
        return ["/refs/hero.png"], {"identity_1": 1}

    monkeypatch.setattr(dp, "_collect_reference_images", _fake_collect)
    board = tmp_path / "B_v01.png"
    board.write_bytes(b"png-bytes")

    batch = [_shot("EP001_SH01"), _shot("EP001_SH02")]
    payload = build_dispatch_payload(
        shot=batch[0],
        project="_board_ref_attach",
        modality="r2v_multi",
        batch_shots=batch,
        episode="ep_001",
        dry_run=True,
        skip_author=True,
        board_ref_path=str(board),
    )
    refs = payload["reference_images"]
    # The board is last; the collected ref stays first.
    assert refs[-1] == str(board)
    assert refs[0] == "/refs/hero.png"


def test_revalidate_passes_approved_board_into_fingerprint_payload(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Contract: revalidate_succeeded_fingerprints threads the approved board
    artifact into its dry-run payload build, so approval alone changes the
    fingerprint and demotes the succeeded take for re-dispatch."""
    # SimpleNamespace comes from the module-level import; only the names this
    # test alone needs are imported locally.
    from recoil.pipeline.core.take import Scene
    from recoil.pipeline.orchestrator.episode_runner import RerollPreflightError

    captured: dict[str, object] = {}

    def _fake_build(**kwargs):
        # Record the builder kwargs so the board path can be checked below.
        captured.update(kwargs)
        return {"reference_images": ["/refs/hero.png"]}

    monkeypatch.setattr(dp, "build_dispatch_payload", _fake_build)

    beat = Beat(
        beat_id="BATCH_004",
        beat_metadata={
            "modality": "r2v_multi",
            "scene_id": "BATCH_004",
            "shot": dataclasses.asdict(_shot("EP001_SH01")),
            "batch_shots": [dataclasses.asdict(_shot("EP001_SH01"))],
            "grouping": {"strategy": "continuity", "ordinal": 4, "shot_ids": []},
            "inputs_fingerprint": "stale-fp",
        },
    )
    beat.set_board_proposed(
        artifact="prep/ep_001/storyboards/EP001_CONT_004_v01.png",
        source_sha256="x",
    )
    beat.approve_board(approved_by="test")
    fake_take = SimpleNamespace(
        take_id="t1", take_index=1, status="succeeded", take_metadata={}
    )
    beat.takes.append(fake_take)
    beat.primary_take_id = "t1"
    scene = Scene(scene_id="BATCH_004", beats=[beat], scene_metadata={})

    # Bypass __init__ and set only the attributes this test assigns.
    runner = EpisodeRunner.__new__(EpisodeRunner)
    runner.project = "_board_ref_attach"
    runner.episode = "ep_001"
    runner.model_override = None
    runner.tier_override = None
    runner.generate_audio = None

    # Board approval changes the resolved refs -> fingerprint drift detected
    # (mutate=False surfaces drift as RerollPreflightError instead of demoting).
    with pytest.raises(RerollPreflightError):
        runner.revalidate_succeeded_fingerprints(scene, mutate=False)
    assert (
        captured.get("board_ref_path")
        == "prep/ep_001/storyboards/EP001_CONT_004_v01.png"
    )
