"""POST /reroll route — REC-111 (consult decision #17).

Mounts on the REAL app object; monkeypatches the runner seam with a fake that
records calls. Covers: contract shape, dry_run estimate with zero dispatches,
invalid batch_id → 422, batch_not_single_beat, and strategy override reaching
the fake.
"""
from __future__ import annotations

import pytest
from fastapi.testclient import TestClient

# Importing the REAL app wires up the api.* namespace + registers /reroll.
from recoil.pipeline.api.main import app
import recoil.pipeline.api.routes.reroll as reroll_mod
from recoil.pipeline.core.persistence import (
    SceneVersionConflictError,
    save_scene,
    scene_path,
)
from recoil.pipeline.core.take import Beat, Scene


@pytest.fixture(autouse=True)
def _isolate(tmp_path, monkeypatch):
    """Give every test its own throwaway projects root.

    Lays out ``<tmp_path>/projects`` with the ``.recoil-data-root`` marker
    file and an empty ``fixture`` project directory, then points the
    ``RECOIL_PROJECTS_ROOT`` environment variable at it for the test.
    """
    projects_root = tmp_path / "projects"
    projects_root.mkdir()
    marker_file = projects_root / ".recoil-data-root"
    marker_file.touch()
    fixture_project = projects_root / "fixture"
    fixture_project.mkdir()
    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", f"{projects_root}")
    yield


class _FakeRunner:
    """Records the runner-seam calls the route makes."""

    def __init__(self, take_status="succeeded", conflict_on=None):
        self.estimate_calls = []
        self.prepare_calls = []
        self.run_scene_calls = []
        self.take_status = take_status
        self.conflict_on = conflict_on

    def _estimate_take_cost(self, beat):
        self.estimate_calls.append(beat)
        return 2.5

    def prepare_beat_for_reroll(self, scene, beat, *, expected_version=None):
        self.prepare_calls.append((scene, beat))
        return {
            "beat_id": beat.beat_id,
            "cleared_stale_primary": None,
            "next_take_index": len(beat.takes),
        }

    async def run_scene(self, scene, **kwargs):
        self.run_scene_calls.append(kwargs)
        if self.conflict_on == "run_scene":
            raise SceneVersionConflictError("BATCH_004", 2, 3)
        from types import SimpleNamespace

        beat = scene.beats[0]
        beat.takes.append(
            SimpleNamespace(take_index=len(beat.takes), status=self.take_status)
        )
        return scene


def _save_batch_scene(
    project="fixture",
    episode_token="ep_001",
    scene_id="BATCH_004",
    *,
    strategy="continuity",
    ordinal=4,
    n_beats=1,
    modality="r2v_multi",
):
    """Persist a batch ``Scene`` for the route under test to load.

    Args:
        project: project name passed to ``scene_path``.
        episode_token: episode token passed to ``scene_path``.
        scene_id: scene id. With ``n_beats == 1`` it is also the beat id;
            otherwise beats are named ``f"{scene_id}__{i}"``.
        strategy: grouping strategy stored on the scene and on every beat.
        ordinal: grouping ordinal stored on the scene and on every beat.
        n_beats: number of beats to create.
        modality: ``beat_metadata["modality"]`` for every beat.
    """

    def _grouping():
        # Build a fresh dict (and a fresh ``shot_ids`` list) for each owner.
        # A shallow ``dict(...)`` copy would share one mutable list between
        # the scene metadata and every beat's metadata.
        return {
            "strategy": strategy,
            "ordinal": ordinal,
            "shot_ids": [],
            "source_pass_id": None,
        }

    beats = [
        Beat(
            beat_id=(f"{scene_id}__{i}" if n_beats > 1 else scene_id),
            max_takes=5,
            beat_metadata={
                "modality": modality,
                "grouping": _grouping(),
                "scene_id": scene_id,
            },
        )
        for i in range(n_beats)
    ]
    scene = Scene(
        scene_id=scene_id,
        beats=beats,
        scene_metadata={"grouping": _grouping()},
    )
    save_scene(scene, scene_path(project, episode_token, scene_id))


def _install_fake(
    monkeypatch, take_status="succeeded", conflict_on=None
) -> _FakeRunner:
    """Swap the route's runner factory for one that returns a single fake.

    Every ``_runner_for_reroll(project, episode)`` call made while the
    patch is active returns the same ``_FakeRunner``. That fake is also
    returned here so the test can inspect its recorded calls.
    """
    runner = _FakeRunner(take_status=take_status, conflict_on=conflict_on)

    def _factory(project, episode):
        return runner

    monkeypatch.setattr(reroll_mod, "_runner_for_reroll", _factory)
    return runner


def test_reroll_route_mounted_on_real_app():
    """The production ``app`` exposes a route whose path is ``/reroll``."""
    mounted_paths = {getattr(route, "path", None) for route in app.routes}
    assert "/reroll" in mounted_paths


def test_dry_run_returns_estimate_with_zero_dispatches(monkeypatch):
    """dry_run=True returns only the cost estimate and runs no seam mutations."""
    _save_batch_scene()
    runner = _install_fake(monkeypatch)
    client = TestClient(app)
    payload = {
        "project": "fixture",
        "episode": 1,
        "batch_id": "EP001_CONT_004",
        "dry_run": True,
    }

    response = client.post("/reroll", json=payload)

    assert response.status_code == 200
    assert response.json() == {"dispatched": [], "budget_estimate_usd": 2.5}
    # A dry run makes zero state writes: neither prepare nor dispatch runs.
    assert not runner.prepare_calls
    assert not runner.run_scene_calls


def test_live_dispatch_contract_shape_and_strategy_override(monkeypatch):
    """A live reroll returns the dispatch contract and forwards the overrides."""
    _save_batch_scene()
    runner = _install_fake(monkeypatch)
    client = TestClient(app)
    payload = {
        "project": "fixture",
        "episode": 1,
        "batch_id": "EP001_CONT_004",
        "strategy": "shot_spec",
        "note": "tighten the push-in",
    }

    response = client.post("/reroll", json=payload)

    assert response.status_code == 200
    result = response.json()
    assert result["budget_estimate_usd"] == 2.5
    expected_entry = {
        "beat_id": "BATCH_004",
        "take_number": 1,
        "batch_file": "ep_001_BATCH_004.json",
    }
    assert result["dispatched"] == [expected_entry]

    # The strategy override and note must reach the runner seam as explicit kwargs.
    assert len(runner.run_scene_calls) == 1
    forwarded = runner.run_scene_calls[0]
    expected_kwargs = {
        "strategy_override": "shot_spec",
        "reroll_note": "tighten the push-in",
        "reroll_beat_id": "BATCH_004",
    }
    for key, value in expected_kwargs.items():
        assert forwarded[key] == value
    assert forwarded["force_new_take"] is True
    assert len(runner.prepare_calls) == 1


def test_invalid_batch_id_returns_422(monkeypatch):
    """A batch_id that is not a batch selector is rejected as invalid_batch_selector."""
    _install_fake(monkeypatch)
    client = TestClient(app)
    request_body = {"project": "fixture", "episode": 1, "batch_id": "not-a-selector"}

    response = client.post("/reroll", json=request_body)

    assert response.status_code == 422
    assert response.json()["error"] == "invalid_batch_selector"


def test_missing_project_or_episode_returns_422(monkeypatch):
    """Omitting either ``project`` or ``episode`` from the body yields a 422."""
    _install_fake(monkeypatch)
    client = TestClient(app)
    incomplete_bodies = (
        {"episode": 1, "batch_id": "EP001_CONT_004"},  # no project
        {"project": "fixture", "batch_id": "EP001_CONT_004"},  # no episode
    )

    for body in incomplete_bodies:
        assert client.post("/reroll", json=body).status_code == 422


def test_batch_not_single_beat_returns_422(monkeypatch):
    """A batch scene persisted with two beats is rejected as batch_not_single_beat."""
    _save_batch_scene(n_beats=2)
    _install_fake(monkeypatch)
    client = TestClient(app)
    request_body = {"project": "fixture", "episode": 1, "batch_id": "EP001_CONT_004"}

    response = client.post("/reroll", json=request_body)

    assert response.status_code == 422
    assert response.json()["error"] == "batch_not_single_beat"


def test_metadata_mismatch_returns_422(monkeypatch):
    """A persisted grouping that disagrees with the selector is rejected."""
    # Selector EP001_CONT_004 implies ordinal 4; the saved scene says 5.
    _save_batch_scene(ordinal=5)
    _install_fake(monkeypatch)
    client = TestClient(app)
    request_body = {"project": "fixture", "episode": 1, "batch_id": "EP001_CONT_004"}

    response = client.post("/reroll", json=request_body)

    assert response.status_code == 422
    assert response.json()["error"] == "batch_selector_metadata_mismatch"


def test_missing_scene_returns_404(monkeypatch):
    """A well-formed selector whose scene was never saved yields a 404."""
    _install_fake(monkeypatch)
    client = TestClient(app)
    request_body = {"project": "fixture", "episode": 1, "batch_id": "EP001_ONER_009"}

    response = client.post("/reroll", json=request_body)

    assert response.status_code == 404
    assert response.json()["error"] == "batch_scene_missing"



def test_unknown_strategy_rejected_422_no_dispatch(monkeypatch):
    """Merge-gate r2: unknown strategy fails closed BEFORE any state mutation.

    Both the prepare step (state mutation) and the dispatch must be skipped;
    checking only the dispatch would miss a route that prepares the beat
    and only then rejects the strategy.
    """
    _save_batch_scene()
    fake = _install_fake(monkeypatch)
    client = TestClient(app)
    resp = client.post(
        "/reroll",
        json={
            "project": "fixture",
            "episode": 1,
            "batch_id": "EP001_CONT_004",
            "strategy": "not_a_real_strategy",
            "dry_run": False,
        },
    )
    assert resp.status_code == 422
    assert resp.json()["error"] == "unknown_author_strategy"
    assert fake.prepare_calls == []
    assert fake.run_scene_calls == []


def test_non_r2v_beat_rejected_before_any_action(monkeypatch):
    """Merge-gate r7: a non-r2v_multi beat must 422 before dry-run or prepare.

    The request is a dry run, so ``run_scene`` would never be called
    anyway. The meaningful checks are that neither the dry-run estimate
    nor the prepare step ran.
    """
    _save_batch_scene(modality="video_i2v")
    fake = _install_fake(monkeypatch)
    client = TestClient(app)
    resp = client.post("/reroll", json={
        "project": "fixture", "episode": 1,
        "batch_id": "EP001_CONT_004", "dry_run": True,
    })
    assert resp.status_code == 422
    assert resp.json()["error"] == "batch_not_single_beat"
    assert fake.estimate_calls == []
    assert fake.prepare_calls == []
    assert fake.run_scene_calls == []


def test_failed_dispatch_reported_as_502(monkeypatch):
    """Merge-gate r9: a failed new take must NOT return a 200 dispatched response."""
    _save_batch_scene()
    # The fake's run_scene appends a take whose status is "failed".
    _install_fake(monkeypatch, take_status="failed")
    client = TestClient(app)
    request_body = {
        "project": "fixture",
        "episode": 1,
        "batch_id": "EP001_CONT_004",
        "dry_run": False,
    }

    response = client.post("/reroll", json=request_body)

    assert response.status_code == 502
    assert response.json()["error"] == "dispatch_failed"


def test_scene_version_conflict_returns_409_structured_body(monkeypatch):
    """A version conflict raised during dispatch surfaces as a structured 409."""
    _save_batch_scene()
    # The fake raises SceneVersionConflictError("BATCH_004", 2, 3) in run_scene.
    _install_fake(monkeypatch, conflict_on="run_scene")
    client = TestClient(app, raise_server_exceptions=False)
    request_body = {
        "project": "fixture",
        "episode": 1,
        "batch_id": "EP001_CONT_004",
        "dry_run": False,
    }

    response = client.post("/reroll", json=request_body)

    assert response.status_code == 409
    conflict = response.json()
    assert conflict["success"] is False
    expected_fields = {
        "error": "scene_version_conflict",
        "batch_id": "BATCH_004",
        "expected_version": 2,
        "current_version": 3,
    }
    for field, value in expected_fields.items():
        assert conflict[field] == value
