from __future__ import annotations

import asyncio
import hashlib
from pathlib import Path

import pytest

from recoil.pipeline._lib import derivation_manifest
from recoil.pipeline._lib.grouping import GroupingContext, get_grouping
from recoil.pipeline._lib.plan_loader import CanonicalPlan, CanonicalShot
from recoil.pipeline.core.persistence import (
    load_manifest, save_scene, scene_path, scene_version_path,
)
from recoil.pipeline.core.take import Beat, Scene
from recoil.pipeline.orchestrator.episode_runner import EpisodeRunner


# Project name shared by every helper in this module: the autouse fixture
# creates a directory with this name under the temporary projects root, and
# the same value is handed to the runner, plan, and persistence helpers.
PROJECT = "fixture"
# Episode identifier in string form, as passed to the plan, the runner, and
# the scene-path helpers.
EPISODE = "ep_001"
# Integer episode number, as passed to ``derivation_manifest`` and
# ``GroupingContext`` (presumably the numeric form of EPISODE — confirm).
EPISODE_NUM = 1


@pytest.fixture(autouse=True)
def _projects_root(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Build an isolated projects tree under ``tmp_path`` for every test.

    Creates ``projects/`` holding an empty ``.recoil-data-root`` file and a
    ``PROJECT`` sub-directory, then sets ``RECOIL_PROJECTS_ROOT`` to that tree
    and ``RECOIL_BOARD_GATE`` to ``"0"`` for the duration of the test.

    Returns:
        The path of the per-project directory.
    """
    projects_dir = tmp_path / "projects"
    fixture_dir = projects_dir / PROJECT

    projects_dir.mkdir()
    projects_dir.joinpath(".recoil-data-root").touch()
    fixture_dir.mkdir()

    environment = (
        ("RECOIL_PROJECTS_ROOT", str(projects_dir)),
        ("RECOIL_BOARD_GATE", "0"),
    )
    for name, value in environment:
        monkeypatch.setenv(name, value)
    return fixture_dir


def _shot(batch_index: int, shot_index: int) -> CanonicalShot:
    """Create one env-only, dialogue-free video shot for the synthetic plan.

    Args:
        batch_index: 1-based batch number; also selects the location id.
        shot_index: 1-based position of the shot inside its batch.

    Returns:
        A ``CanonicalShot`` whose typed fields mirror the values stored in
        its ``raw`` payload.
    """
    # Three shots per batch: ordinals run 1..3 for batch 1, 4..6 for batch 2, ...
    ordinal = 3 * (batch_index - 1) + shot_index
    shot_id = f"EP001_SH{ordinal:02d}"
    location_id = f"LOC_{batch_index}"
    pipeline = "video"
    video_model = "seeddance-2.0"
    aspect_ratio = "9:16"
    shot_type = "MS"
    duration_s = 2.0

    asset_data = {"location_id": location_id, "characters": []}
    routing_data = {
        "target_editorial_duration_s": duration_s,
        "is_env_only": True,
        "has_dialogue": False,
    }
    raw = {
        "shot_id": shot_id,
        "scene_index": ordinal,
        "pipeline": pipeline,
        "video_model": video_model,
        "aspect_ratio": aspect_ratio,
        "asset_data": asset_data,
        "prompt_data": {"shot_type": shot_type},
        "routing_data": routing_data,
        "source_text_hash": f"hash-{ordinal}",
    }
    return CanonicalShot(
        shot_id=shot_id,
        scene_index=ordinal,
        sequence_id=None,
        pipeline=pipeline,
        previs_model=None,
        video_model=video_model,
        location_id=location_id,
        characters=[],
        shot_type=shot_type,
        duration_s=duration_s,
        is_env_only=True,
        has_dialogue=False,
        aspect_ratio=aspect_ratio,
        raw=raw,
    )


def _plan(group_count: int) -> CanonicalPlan:
    """Assemble a synthetic plan with ``group_count`` batches of three shots.

    The raw payload carries shallow copies of each shot's ``raw`` dict plus a
    ``sequences`` mapping with one ``BATCH_NNN`` entry per batch, each listing
    two ``INIT_NNN_MM`` shot ids.
    """
    batch_numbers = range(1, group_count + 1)
    shots = [
        _shot(batch_index, shot_index)
        for batch_index in batch_numbers
        for shot_index in range(1, 4)
    ]
    raw_shots = [dict(shot.raw) for shot in shots]

    def _sequence_entry(index: int) -> dict:
        # Each sequence lists two placeholder shot ids for its batch.
        return {
            "shots": [
                {"shot_id": f"INIT_{index:03d}_{shot_index:02d}"}
                for shot_index in range(1, 3)
            ]
        }

    sequences = {
        f"BATCH_{index:03d}": _sequence_entry(index) for index in batch_numbers
    }
    raw_plan = {
        "episode_id": EPISODE,
        "project": PROJECT,
        "total_shots": len(shots),
        "shots": raw_shots,
        "sequences": sequences,
    }
    return CanonicalPlan(
        episode_id=EPISODE,
        project=PROJECT,
        shots=shots,
        source_path=Path("ep_001_plan.json"),
        raw=raw_plan,
    )


def _runner(plan: CanonicalPlan) -> EpisodeRunner:
    """Return an ``EpisodeRunner`` for the fixture project fed with ``plan.raw``.

    The runner is created with empty casting, ``max_takes=9`` and
    ``concurrency=1``.
    """
    runner_kwargs = {
        "project": PROJECT,
        "plan": plan.raw,
        "casting": {},
        "episode": EPISODE,
        "max_takes": 9,
        "concurrency": 1,
    }
    return EpisodeRunner(**runner_kwargs)


def _stale_scene(scene_id: str) -> Scene:
    """Return a one-beat ``Scene`` whose metadata is marked ``stale``.

    The single beat is named ``OLD_<scene_id>`` so it can be told apart from
    anything derived later.
    """
    placeholder_beat = Beat(
        beat_id=f"OLD_{scene_id}",
        beat_metadata={"marker": "stale"},
        max_takes=1,
    )
    # The scene gets its own metadata dict; it is not shared with the beat.
    return Scene(
        scene_id=scene_id,
        beats=[placeholder_beat],
        scene_metadata={"marker": "stale"},
    )


def _write_scene(scene_id: str) -> Path:
    """Save a stale placeholder scene for ``scene_id`` and return its path."""
    destination = scene_path(PROJECT, EPISODE, scene_id)
    placeholder = _stale_scene(scene_id)
    save_scene(placeholder, destination)
    return destination


def _sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the bytes stored at ``path``."""
    hasher = hashlib.sha256()
    hasher.update(path.read_bytes())
    return hasher.hexdigest()


def _stamp_scenes_stage() -> dict:
    """Write a sentinel ``scenes`` stage into the derivation manifest.

    Loads the manifest for the fixture project and episode, replaces
    ``stages["scenes"]`` with a shallow copy of the sentinel, and saves it.

    Returns:
        A separate shallow copy of the sentinel, for later comparison.
    """
    sentinel = {
        "kind": "derived",
        "content_sha": "sha256:before-content",
        "structural_sha": None,
        "source": {"plan_structural_sha": "sha256:before-plan"},
        "built_at": "2026-06-21T00:00:00Z",
        "builder": "test",
        "model": None,
        "via": None,
        "scene_ids": ["BATCH_001", "BATCH_002", "BATCH_003"],
        "shot_script_spans": {"BATCH_001": {"EP001_SH01": "old"}},
    }
    current = derivation_manifest.load(PROJECT, EPISODE_NUM)
    current["stages"]["scenes"] = sentinel.copy()
    derivation_manifest.save(PROJECT, EPISODE_NUM, current)
    return sentinel.copy()


def _full_group_ids(plan: CanonicalPlan) -> list[str]:
    """Return the scene ids the ``continuity`` grouping yields for ``plan``.

    All of the plan's shots are grouped, with no coverage passes, an empty
    tier map, and no wildcard override.
    """
    context = GroupingContext(
        project=PROJECT,
        episode=EPISODE_NUM,
        canonical_plan=plan,
        selected_coverage_passes=[],
        tier_map={},
        wildcard_override=None,
    )
    strategy = get_grouping("continuity")
    groups = strategy.assemble(list(plan.shots), context)
    return [group.scene_id for group in groups]


def _trap_run_scene(calls: list[str]):
    """Build an async ``run_scene`` stand-in that fails whenever it is awaited.

    The returned coroutine function appends the scene's id to ``calls`` and
    then raises ``AssertionError``.
    """
    failure_message = "run_scene must not be called in derive_only mode"

    async def _run_scene(scene, **kwargs):  # noqa: ANN001, ANN003
        # Record the id before raising so the caller can see what was attempted.
        attempted_id = scene.scene_id
        calls.append(attempted_id)
        raise AssertionError(failure_message)

    return _run_scene


def test_only_scene_ids_derive_only_writes_target_preserves_siblings_and_manifest(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Derive only one of three scenes and check nothing else is touched.

    Only the target scene may be derived; ``run_scene`` must never run; the
    sibling files and the manifest's ``scenes`` stage must be unchanged.
    """
    plan = _plan(group_count=3)
    all_ids = _full_group_ids(plan)
    assert all_ids == ["BATCH_001", "BATCH_002", "BATCH_003"]
    target_id = all_ids[0]

    # Seed the disk: target and one sibling exist, the other sibling does not.
    target_file = _write_scene(target_id)
    sibling_file = _write_scene("BATCH_002")
    absent_file = scene_path(PROJECT, EPISODE, "BATCH_003")
    assert not absent_file.exists()
    target_digest = _sha256(target_file)
    sibling_digest = _sha256(sibling_file)
    expected_stage = _stamp_scenes_stage()

    runner = _runner(plan)
    run_scene_calls: list[str] = []
    derived_ids: list[str] = []
    real_scene_from_group = runner._scene_from_group

    def _spy_scene_from_group(group):  # noqa: ANN001
        # Record which groups are derived, then defer to the real method.
        derived_ids.append(group.scene_id)
        return real_scene_from_group(group)

    monkeypatch.setattr(runner, "_scene_from_group", _spy_scene_from_group)
    monkeypatch.setattr(runner, "run_scene", _trap_run_scene(run_scene_calls))

    pending = runner.run_episode_batches(
        plan,
        derive_only=True,
        only_scene_ids={target_id},
    )
    result = asyncio.run(pending)

    assert result["derive_only"] is True
    assert result["written"] == [target_id]
    assert result["skipped"] == []
    assert run_scene_calls == []
    assert derived_ids == [target_id]
    # REC-231 Phase 2: the target's flat body (v1) stays byte-identical; the
    # re-derived structure lands as a v2 candidate and the active pointer does
    # not move. Unselected siblings remain exactly as they were.
    assert _sha256(target_file) == target_digest
    target_manifest = load_manifest(PROJECT, EPISODE, target_id)
    assert target_manifest["active_version"] == 1
    assert scene_version_path(PROJECT, EPISODE, target_id, 2).exists()
    assert _sha256(sibling_file) == sibling_digest
    assert not absent_file.exists()
    stage_after = derivation_manifest.load(PROJECT, EPISODE_NUM)["stages"]["scenes"]
    assert stage_after == expected_stage


def test_only_scene_ids_single_scene_episode_does_not_stamp_manifest(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Derive the only scene of a one-scene episode.

    Even when the selection covers every scene, the manifest's ``scenes``
    stage must still equal the sentinel written beforehand.
    """
    plan = _plan(group_count=1)
    all_ids = _full_group_ids(plan)
    assert all_ids == ["BATCH_001"]
    target_id = all_ids[0]
    _write_scene(target_id)
    expected_stage = _stamp_scenes_stage()

    runner = _runner(plan)
    run_scene_calls: list[str] = []
    monkeypatch.setattr(runner, "run_scene", _trap_run_scene(run_scene_calls))

    pending = runner.run_episode_batches(
        plan,
        derive_only=True,
        only_scene_ids={target_id},
    )
    result = asyncio.run(pending)

    assert result["written"] == [target_id]
    assert run_scene_calls == []
    stage_after = derivation_manifest.load(PROJECT, EPISODE_NUM)["stages"]["scenes"]
    assert stage_after == expected_stage
