from __future__ import annotations

import json

import pytest

from recoil.pipeline._lib.schema_versions import SCENE_VERSIONS_MANIFEST_SCHEMA_VERSION
from recoil.pipeline.core.persistence import (
    SceneStructureImmutableError,
    load_manifest,
    load_scene,
    save_active_scene,
    save_scene,
    scene_manifest_path,
    scene_path,
    scene_version_path,
    structure_fingerprint,
)
from recoil.pipeline.core.take import Beat, Scene, Take
from recoil.pipeline.core.workflow import Workflow, WorkflowStep


def _configure_project(tmp_path, monkeypatch) -> None:
    """Lay out a throwaway projects root under *tmp_path*.

    Creates the ``.recoil-data-root`` marker file and an empty ``fixture``
    project directory, then points ``RECOIL_PROJECTS_ROOT`` at *tmp_path* for
    the duration of the test (``monkeypatch`` undoes the env change).
    """
    marker = tmp_path / ".recoil-data-root"
    marker.touch()
    project_dir = tmp_path / "fixture"
    project_dir.mkdir()
    root = str(tmp_path)
    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", root)


def _shot(**overrides):
    data = {
        "shot_id": "EP001_SH01",
        "scene_index": 1,
        "sequence_id": None,
        "pipeline": "video",
        "previs_model": None,
        "video_model": "seeddance-2.0",
        "location_id": "LOC_A",
        "characters": [],
        "shot_type": "MS",
        "duration_s": 2.0,
        "is_env_only": True,
        "has_dialogue": False,
        "aspect_ratio": "9:16",
        "raw": {"shot_id": "EP001_SH01", "description": "disk"},
        "cinematography": None,
        "quality": None,
    }
    data.update(overrides)
    return data


def _scene(
    *,
    shot=None,
    batch_shots=None,
    board=None,
    beat_metadata_extra=None,
    takes=None,
    primary_take_id=None,
    phantom_recovery_count=0,
) -> Scene:
    """Build a single-beat ``BATCH_001`` scene for the persistence tests.

    Args:
        shot: Shot dict stored as ``beat_metadata["shot"]``. Defaults to a
            fresh ``_shot()``.
        batch_shots: List stored as ``beat_metadata["batch_shots"]``. Defaults
            to a one-element list: *shot* itself (the same object) when *shot*
            is given, otherwise a separate fresh ``_shot()``. An explicitly
            passed value, including an empty list, is used as-is.
        board: Value for ``Beat.board``.
        beat_metadata_extra: Extra keys merged over the default beat metadata;
            these win over the defaults.
        takes: Takes for the beat; copied into a new list.
        primary_take_id: Value for ``Beat.primary_take_id``.
        phantom_recovery_count: Value for ``Beat.phantom_recovery_count``.

    Returns:
        A ``Scene`` named ``BATCH_001`` whose only beat is ``BATCH_001__cov``.
    """
    # Explicit None checks, so falsy-but-meaningful arguments such as
    # ``batch_shots=[]`` are honoured instead of silently replaced.
    # batch_shots is resolved first so that, when *shot* is omitted, the
    # default batch entry is its own _shot() dict, not an alias of "shot".
    if batch_shots is None:
        batch_shots = [_shot() if shot is None else shot]
    if shot is None:
        shot = _shot()

    beat_metadata = {
        "scene_id": "BATCH_001",
        "shot": shot,
        "batch_shots": batch_shots,
        "modality": "r2v_multi",
        "grouping": {
            "strategy": "coverage",
            "ordinal": 1,
            "shot_ids": ["EP001_SH01"],
            "source_pass_id": "PASS_001",
            "shotset_hash": "abc",
        },
    }
    if beat_metadata_extra is not None:
        beat_metadata.update(beat_metadata_extra)
    return Scene(
        scene_id="BATCH_001",
        beats=[
            Beat(
                "BATCH_001__cov",
                beat_metadata=beat_metadata,
                board=board,
                takes=list(takes or []),
                primary_take_id=primary_take_id,
                phantom_recovery_count=phantom_recovery_count,
            )
        ],
    )


def _wf() -> Workflow:
    """Return a minimal workflow ``wf`` with a single ``video`` step."""
    only_step = WorkflowStep(step_id="video", modality="video_i2v", payload={})
    return Workflow(workflow_id="wf", steps=[only_step])


def test_writer1_preserves_concurrent_board_and_overlays_runner_status(
    tmp_path,
    monkeypatch,
):
    """Writer 1 keeps the on-disk board and overlays the runner's beat state.

    The scene saved to disk carries an approved board. The runner's copy has
    ``board=None`` but carries a succeeded take, a primary take id, a
    phantom-recovery count and a fresh ``inputs_fingerprint``. After
    ``save_active_scene`` the persisted beat must contain all of them.
    """
    _configure_project(tmp_path, monkeypatch)
    target = scene_path("fixture", "ep_001", "BATCH_001")

    disk_board = dict(
        status="approved",
        artifact="prep/ep_001/storyboards/BATCH_001.png",
        source_sha256="a" * 64,
        approved_by="jt",
        updated_at="2026-06-22T00:00:00Z",
    )
    save_scene(_scene(board=disk_board), target)

    take_id = "BATCH_001__cov_take_0"
    from_runner = _scene(
        board=None,
        takes=[Take(take_id, 0, _wf(), status="succeeded")],
        primary_take_id=take_id,
        phantom_recovery_count=1,
        beat_metadata_extra={"inputs_fingerprint": "fresh"},
    )
    save_active_scene(
        "fixture",
        "ep_001",
        "BATCH_001",
        from_runner,
        expected_version=1,
    )

    persisted = load_scene(target).beats[0]
    assert persisted.board == disk_board
    assert len(persisted.takes) == 1
    assert persisted.primary_take_id == "BATCH_001__cov_take_0"
    assert persisted.phantom_recovery_count == 1
    assert persisted.beat_metadata["inputs_fingerprint"] == "fresh"


def test_writer1_ignores_builder_variant_shot_representation(tmp_path, monkeypatch):
    """The on-disk shot dict survives a runner save that carries a variant shot.

    The runner's shot differs in ``is_env_only``, ``aspect_ratio`` and ``raw``.
    The persisted shot must still match the disk version, while the runner's
    ``inputs_fingerprint`` is written through.
    """
    _configure_project(tmp_path, monkeypatch)
    target = scene_path("fixture", "ep_001", "BATCH_001")

    on_disk = _shot(is_env_only=True, aspect_ratio="9:16", raw={"description": "disk"})
    from_runner = _shot(
        is_env_only=False, aspect_ratio=None, raw={"description": "runner"}
    )
    save_scene(_scene(shot=on_disk), target)

    runner_scene = _scene(
        shot=from_runner, beat_metadata_extra={"inputs_fingerprint": "fresh"}
    )
    save_active_scene(
        "fixture",
        "ep_001",
        "BATCH_001",
        runner_scene,
        expected_version=1,
    )

    metadata = json.loads(target.read_text())["beats"][0]["beat_metadata"]
    shot = metadata["shot"]
    assert shot["is_env_only"] is True
    assert shot["aspect_ratio"] == "9:16"
    assert shot["raw"] == {"description": "disk"}
    assert metadata["inputs_fingerprint"] == "fresh"


def test_writer1_enriches_absent_structural_keys_but_rejects_clobber(
    tmp_path,
    monkeypatch,
):
    """A key missing on disk may be added; an existing one may not be changed.

    First, the runner adds ``batch_summary``, which the disk body lacks, and
    the save succeeds. Then a runner scene whose shot ``duration_s`` differs
    from disk must raise ``SceneStructureImmutableError``.
    """
    _configure_project(tmp_path, monkeypatch)
    target = scene_path("fixture", "ep_001", "BATCH_001")

    baseline = _scene()
    # Make the "batch_summary absent on disk" precondition explicit.
    baseline.beats[0].beat_metadata.pop("batch_summary", None)
    save_scene(baseline, target)

    enriching = _scene(
        beat_metadata_extra={"batch_summary": {"shot_ids": ["EP001_SH01"]}}
    )
    save_active_scene(
        "fixture",
        "ep_001",
        "BATCH_001",
        enriching,
        expected_version=1,
    )
    metadata = load_scene(target).beats[0].beat_metadata
    assert metadata["batch_summary"] == {"shot_ids": ["EP001_SH01"]}

    mutated_shot = _scene(shot=_shot(duration_s=9.0))
    with pytest.raises(SceneStructureImmutableError):
        save_active_scene(
            "fixture",
            "ep_001",
            "BATCH_001",
            mutated_shot,
            expected_version=1,
        )


def test_writer1_enriches_absent_structural_keys_only_for_flat_bodies(
    tmp_path,
    monkeypatch,
):
    """Absent structural keys are enriched into flat bodies only.

    Part 1 covers a registered body: ``v002`` of BATCH_001, listed in the
    manifest with its structure fingerprint. It must not gain the runner's
    ``batch_summary``. Both the manifest's stored fingerprint and a freshly
    recomputed one must still equal the original.

    Part 2 covers a flat body for BATCH_002 saved at ``scene_path``. It is
    enriched with the runner's ``batch_summary``.
    """

    def _as_batch(scene: Scene, batch_id: str) -> Scene:
        # Re-key a default _scene() (built for BATCH_001) to *batch_id*.
        scene.scene_id = batch_id
        scene.beats[0].beat_id = f"{batch_id}__cov"
        scene.beats[0].beat_metadata["scene_id"] = batch_id
        return scene

    _configure_project(tmp_path, monkeypatch)

    # --- Part 1: registered (versioned) body must not be enriched. ---
    registered_body = _scene()
    registered_body.beats[0].beat_metadata.pop("batch_summary", None)
    registered_fp = structure_fingerprint(registered_body)
    save_scene(registered_body, scene_version_path("fixture", "ep_001", "BATCH_001", 2))
    manifest_path = scene_manifest_path("fixture", "ep_001", "BATCH_001")
    # Create the directory explicitly rather than relying on save_scene having
    # done so as a side effect (matches the missing-body test below).
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.write_text(
        json.dumps({
            "schema_version": SCENE_VERSIONS_MANIFEST_SCHEMA_VERSION,
            "batch_id": "BATCH_001",
            "active_version": 2,
            "versions": [
                {
                    "version": 2,
                    "artifact": "ep_001_BATCH_001.v002.json",
                    "state": "approved",
                    "downstream": "derived",
                    "structure_fingerprint": registered_fp,
                }
            ],
        })
    )

    runner_scene = _scene(
        beat_metadata_extra={"batch_summary": {"shot_ids": ["EP001_SH01"]}}
    )
    save_active_scene(
        "fixture", "ep_001", "BATCH_001", runner_scene, expected_version=2
    )

    registered_after = load_scene(
        scene_version_path("fixture", "ep_001", "BATCH_001", 2)
    )
    assert "batch_summary" not in registered_after.beats[0].beat_metadata
    manifest_after = load_manifest("fixture", "ep_001", "BATCH_001")
    assert manifest_after["versions"][0]["structure_fingerprint"] == registered_fp
    assert structure_fingerprint(registered_after) == registered_fp

    # --- Part 2: flat body for a second batch is enriched. ---
    flat_body = _as_batch(_scene(), "BATCH_002")
    flat_body.beats[0].beat_metadata.pop("batch_summary", None)
    save_scene(flat_body, scene_path("fixture", "ep_001", "BATCH_002"))
    flat_runner = _as_batch(
        _scene(beat_metadata_extra={"batch_summary": {"shot_ids": ["EP001_SH01"]}}),
        "BATCH_002",
    )
    save_active_scene(
        "fixture", "ep_001", "BATCH_002", flat_runner, expected_version=1
    )

    flat_after = load_scene(scene_path("fixture", "ep_001", "BATCH_002"))
    assert flat_after.beats[0].beat_metadata["batch_summary"] == {
        "shot_ids": ["EP001_SH01"]
    }


def test_writer1_missing_registered_body_fails_closed(tmp_path, monkeypatch):
    """If the manifest points at a version file that is absent, the save errors.

    The manifest names ``v002`` as active, but no body file is written. The
    test expects ``save_active_scene`` to raise ``FileNotFoundError``.
    """
    _configure_project(tmp_path, monkeypatch)

    manifest = {
        "schema_version": SCENE_VERSIONS_MANIFEST_SCHEMA_VERSION,
        "batch_id": "BATCH_001",
        "active_version": 2,
        "versions": [
            {
                "version": 2,
                "artifact": "ep_001_BATCH_001.v002.json",
                "state": "approved",
                "downstream": "derived",
            }
        ],
    }
    manifest_file = scene_manifest_path("fixture", "ep_001", "BATCH_001")
    manifest_file.parent.mkdir(parents=True, exist_ok=True)
    manifest_file.write_text(json.dumps(manifest))

    with pytest.raises(FileNotFoundError):
        save_active_scene(
            "fixture",
            "ep_001",
            "BATCH_001",
            _scene(),
            expected_version=2,
        )


def test_writer1_never_deserializes_inflight_takes(tmp_path, monkeypatch):
    """Writer 1 persists runner takes via ``to_dict`` without calling ``Take.from_dict``.

    ``Take.from_dict`` is patched to fail the test if invoked. The runner's
    only take serializes to ``{}``, which must appear verbatim in the
    persisted body.
    """
    _configure_project(tmp_path, monkeypatch)
    target = scene_path("fixture", "ep_001", "BATCH_001")
    save_scene(_scene(), target)

    def _forbid_from_dict(_payload):
        raise AssertionError("Writer 1 must not call Take.from_dict")

    monkeypatch.setattr(Take, "from_dict", _forbid_from_dict)
    opaque_take = Take("BATCH_001__cov_take_0", 0, _wf())
    opaque_take.to_dict = lambda: {}

    save_active_scene(
        "fixture",
        "ep_001",
        "BATCH_001",
        _scene(takes=[opaque_take], primary_take_id=None),
        expected_version=1,
    )

    persisted = json.loads(target.read_text())
    assert persisted["beats"][0]["takes"] == [{}]
