from __future__ import annotations

import json

import pytest

from recoil.execution.step_types import ProjectPaths
from recoil.pipeline._lib.grouping import GroupingContext, get_grouping
from recoil.pipeline.cli import generate
from recoil.pipeline.core.persistence import scene_path
from recoil.pipeline.core import persistence
from recoil.pipeline.orchestrator.coverage_planner import CoveragePass
from recoil.pipeline.orchestrator.episode_runner import EpisodeRunner
from recoil.pipeline.orchestrator.tests import test_reroll_new_take as reroll_fixture


@pytest.fixture(autouse=True)
def _reset_module_caches():
    """Clear the registry and dispatch-bootstrap test caches around every test.

    The resets run once before the test body (the ``yield``) and once after
    it, in the same order both times: registry first, then dispatch bootstrap.
    """
    from recoil.pipeline.core.dispatch import _reset_bootstrap_for_tests
    from recoil.pipeline.core.registry import _reset_for_tests

    resetters = (_reset_for_tests, _reset_bootstrap_for_tests)

    for reset in resetters:
        reset()
    yield
    for reset in resetters:
        reset()


def _configure_fixture_project(tmp_path, monkeypatch) -> None:
    """Make *tmp_path* a projects root that contains an empty ``fixture`` project.

    Creates the ``.recoil-data-root`` marker file and the ``fixture`` directory,
    then points ``RECOIL_PROJECTS_ROOT`` at *tmp_path* via *monkeypatch*.
    """
    projects_root = tmp_path

    data_root_marker = projects_root / ".recoil-data-root"
    data_root_marker.touch()

    fixture_dir = projects_root / "fixture"
    fixture_dir.mkdir()

    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", str(projects_root))


def _pass_011_disk_and_rebuild(tmp_path, monkeypatch) -> tuple[dict, dict]:
    """Persist the PASS_011 reroll fixture, then rebuild the same scene in memory.

    Returns ``(disk, rebuild)``:

    * ``disk`` — the JSON read back from
      ``scene_path("fixture", "ep_001", "PASS_011")`` with its
      ``schema_version`` key removed (if present);
    * ``rebuild`` — ``EpisodeRunner._scene_from_group(...).to_dict()`` for the
      first group that the ``"coverage"`` grouping assembles from the canonical
      plan's shots.
    """
    project = "fixture"
    episode_str = "ep_001"
    pass_id = "PASS_011"

    # Lay down the project on disk and persist the pass scene through the
    # reroll test module's helpers.
    _configure_fixture_project(tmp_path, monkeypatch)
    reroll_fixture._write_cli_project()
    succeeded = reroll_fixture._succeeded_beat(tmp_path / "video")
    reroll_fixture._persist_pass_scene(succeeded)

    # Pick out the PASS_011 coverage-pass record(s) and load the canonical plan.
    episode_paths = ProjectPaths.for_episode(project, 1)
    passes_file = episode_paths.coverage_passes_dir / "ep_001_passes.json"
    pass_records = [
        record
        for record in json.loads(passes_file.read_text())
        if record["pass_id"] == pass_id
    ]
    plan = generate._load_selected_canonical_plan(
        paths=episode_paths,
        episode_str=episode_str,
        pass_ids=[pass_id],
        selected_dicts=pass_records,
        force_new_take=True,
    )
    coverage_passes = [CoveragePass.from_dict(record) for record in pass_records]

    # Rebuild the scene in memory from the first assembled coverage group.
    runner = EpisodeRunner(project=project, plan=plan.raw, episode=episode_str)
    grouping_ctx = GroupingContext(
        project=project,
        episode=1,
        canonical_plan=plan,
        selected_coverage_passes=coverage_passes,
        tier_map={},
        wildcard_override=None,
    )
    first_group = get_grouping("coverage").assemble(plan.shots, grouping_ctx)[0]
    rebuilt = runner._scene_from_group(first_group).to_dict()

    # The rebuild has no schema_version, so drop it from the disk copy before
    # the two are compared.
    on_disk = json.loads(scene_path(project, episode_str, pass_id).read_text())
    on_disk.pop("schema_version", None)
    return on_disk, rebuilt


def _divergent_shot_keys(disk: dict, rebuild: dict) -> set[str]:
    divergent: set[str] = set()
    for disk_beat, rebuild_beat in zip(disk["beats"], rebuild["beats"], strict=True):
        disk_md = disk_beat["beat_metadata"]
        rebuild_md = rebuild_beat["beat_metadata"]
        for key in ("shot",):
            disk_shot = disk_md.get(key) or {}
            rebuild_shot = rebuild_md.get(key) or {}
            all_keys = set(disk_shot) | set(rebuild_shot)
            divergent.update(
                k for k in all_keys if disk_shot.get(k) != rebuild_shot.get(k)
            )
        for disk_shot, rebuild_shot in zip(
            disk_md.get("batch_shots") or [],
            rebuild_md.get("batch_shots") or [],
            strict=True,
        ):
            all_keys = set(disk_shot) | set(rebuild_shot)
            divergent.update(
                k for k in all_keys if disk_shot.get(k) != rebuild_shot.get(k)
            )
    return divergent


def test_builder_variant_set_measured_from_live_reroll_fixture(tmp_path, monkeypatch):
    """Pin the exact set of shot keys on which the disk scene and rebuild differ.

    The measured set is printed so it is visible when pytest runs with ``-s``.
    """
    scene_pair = _pass_011_disk_and_rebuild(tmp_path, monkeypatch)
    measured = _divergent_shot_keys(*scene_pair)
    print(f"measured_builder_variant_shot_keys={sorted(measured)}")

    expected_variants = {"is_env_only", "aspect_ratio", "raw"}
    assert measured == expected_variants


def test_canonical_shot_identity_ignores_only_measured_builder_variants(
    tmp_path,
    monkeypatch,
):
    """Disk and rebuild agree once the measured builder variants are ignored.

    The first beat's shot has the same ``canonical_shot_identity`` on both
    sides, and ``structural_mutations`` reports nothing for the pair.
    """
    disk, rebuild = _pass_011_disk_and_rebuild(tmp_path, monkeypatch)
    disk_shot, rebuild_shot = (
        scene["beats"][0]["beat_metadata"]["shot"] for scene in (disk, rebuild)
    )

    disk_identity = persistence.canonical_shot_identity(disk_shot)
    rebuild_identity = persistence.canonical_shot_identity(rebuild_shot)
    assert disk_identity == rebuild_identity

    mutations = persistence.structural_mutations(disk, rebuild)
    assert mutations == []


def test_structural_mutations_fail_closed_on_real_shot_change(tmp_path, monkeypatch):
    """A real change to shot content must be reported as a structural mutation.

    Two independent edits of the rebuilt scene are checked: the first beat's
    ``shot.duration_s`` and the ``shot_id`` of its first ``batch_shots`` entry.
    """
    disk, rebuild = _pass_011_disk_and_rebuild(tmp_path, monkeypatch)
    # Baseline: the unmodified rebuild must be clean. Otherwise the
    # fail-closed assertions below could pass vacuously, e.g. if
    # structural_mutations flagged everything.
    assert persistence.structural_mutations(disk, rebuild) == []

    # A JSON round-trip gives an independent deep copy for each mutation,
    # using the same JSON types as the disk side (loaded with json.loads).
    duration_changed = json.loads(json.dumps(rebuild))
    shot = duration_changed["beats"][0]["beat_metadata"]["shot"]
    # The sentinel must actually differ from the fixture value; otherwise
    # the "mutation" is a no-op.
    assert shot.get("duration_s") != 99
    shot["duration_s"] = 99
    assert persistence.structural_mutations(disk, duration_changed)

    shot_id_changed = json.loads(json.dumps(rebuild))
    batch_shot = shot_id_changed["beats"][0]["beat_metadata"]["batch_shots"][0]
    assert batch_shot.get("shot_id") != "EP001_SH99"
    batch_shot["shot_id"] = "EP001_SH99"
    assert persistence.structural_mutations(disk, shot_id_changed)


def test_structural_mutations_fail_closed_on_topology_and_scene_id(tmp_path, monkeypatch):
    """Changing the scene id or the beat list must be reported as a structural mutation."""
    disk, rebuild = _pass_011_disk_and_rebuild(tmp_path, monkeypatch)
    # Baseline: the unmodified rebuild must be clean, so the assertions
    # below are caused by the mutations themselves.
    assert persistence.structural_mutations(disk, rebuild) == []

    # A JSON round-trip gives an independent deep copy for each mutation.
    scene_changed = json.loads(json.dumps(rebuild))
    assert scene_changed.get("scene_id") != "PASS_999"
    scene_changed["scene_id"] = "PASS_999"
    assert persistence.structural_mutations(disk, scene_changed)

    # Duplicate the first beat under a fresh beat_id and append it. The id is
    # set on the copy itself rather than through beats[1], which is only the
    # appended beat when the rebuild has exactly one beat.
    topology_changed = json.loads(json.dumps(rebuild))
    extra_beat = json.loads(json.dumps(topology_changed["beats"][0]))
    extra_beat["beat_id"] = "PASS_011__extra"
    topology_changed["beats"].append(extra_beat)
    assert persistence.structural_mutations(disk, topology_changed)
