"""REC-111 --batch selector resolution for generate.py.

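Covers selector parsing (CONT→BATCH_NNN, ONER→ONER_NNN), invalid and
non-canonical format rejection, scope mutual-exclusion, the grouping-metadata
cross-check, the legacy-file guard, strategy_override threading into
run_scene, the structured board-gate / scene-version-conflict error payloads,
the REC-231 freshness gate, and the --retry / failed-take / strategy-modality
merge-gate regressions.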
"""
from __future__ import annotations

import dataclasses
import json
import sys
from unittest.mock import MagicMock

import pytest

from recoil.core.paths import ProjectPaths as CoreProjectPaths
from recoil.pipeline.cli import generate
from recoil.pipeline.core.persistence import (
    SceneVersionConflictError,
    save_scene,
    scene_path,
)
from recoil.pipeline.core.scene_version_store import SceneVersionStore
from recoil.pipeline.core.take import Beat, Scene
from recoil.pipeline.core.workflow import Workflow, WorkflowStep
from recoil.pipeline.orchestrator.batch_selector import parse_batch_selector
from recoil.pipeline.orchestrator.episode_runner import EpisodeRunner
from recoil.pipeline.orchestrator.tests.test_reroll_new_take import _shot


@pytest.fixture(autouse=True)
def _isolate(tmp_path, monkeypatch):
    """Point every test at a throwaway projects root and silence the ops log.

    Creates ``<tmp_path>/projects`` containing a ``.recoil-data-root`` marker
    file, exports that directory through ``RECOIL_PROJECTS_ROOT``, and swaps the
    episode runner's ``ops_log.write`` for a no-op so no log records are written.
    """
    projects_root = tmp_path / "projects"
    projects_root.mkdir()
    # Marker file; presumably how the project-path resolver recognises a valid
    # data root — confirm against recoil.core.paths.
    data_root_marker = projects_root / ".recoil-data-root"
    data_root_marker.touch()
    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", str(projects_root))

    def _discard_ops_log(*_args, **_kwargs):
        return None

    monkeypatch.setattr(
        "recoil.pipeline.orchestrator.episode_runner.ops_log.write",
        _discard_ops_log,
    )
    yield


def _save_batch_scene(
    project="fixture",
    episode_token="ep_001",
    scene_id="BATCH_004",
    *,
    strategy="continuity",
    ordinal=4,
):
    """Persist a three-shot ``r2v_multi`` batch scene and return it.

    The scene holds a single beat whose ``beat_id`` equals *scene_id*. The
    ``grouping`` block (strategy, ordinal, shot ids, no source pass) is written
    twice — once into the beat metadata and once into the scene metadata — as
    separate shallow copies. The scene is saved at
    ``scene_path(project, episode_token, scene_id)``.
    """
    batch_shot_ids = ["EP001_SH23", "EP001_SH24", "EP001_SH25"]
    batch_shots = [
        _shot(shot_id, position)
        for position, shot_id in enumerate(batch_shot_ids, start=1)
    ]
    grouping = dict(
        strategy=strategy,
        ordinal=ordinal,
        shot_ids=batch_shot_ids,
        source_pass_id=None,
    )

    beat_metadata = {
        "scene_id": scene_id,
        "modality": "r2v_multi",
        "shot": dataclasses.asdict(batch_shots[0]),
        "batch_shots": [dataclasses.asdict(shot) for shot in batch_shots],
        "grouping": grouping.copy(),
        "inputs_fingerprint": "",
    }
    scene_metadata = {
        "episode": episode_token,
        "project": project,
        "batch": True,
        "grouping": grouping.copy(),
    }

    batch_scene = Scene(
        scene_id=scene_id,
        beats=[Beat(beat_id=scene_id, max_takes=5, beat_metadata=beat_metadata)],
        scene_metadata=scene_metadata,
    )
    save_scene(batch_scene, scene_path(project, episode_token, scene_id))
    return batch_scene


def _add_failed_primary(beat: Beat) -> None:
    """Create a single-step ``r2v_multi`` take on *beat*, fail it, and make it primary.

    The take's workflow id is ``"<beat_id>__take_0"``; its status is set to
    ``"failed"`` and ``beat.primary_take_id`` is pointed at it.
    """
    video_step = WorkflowStep(
        step_id="video",
        modality="r2v_multi",
        payload={"shot_id": beat.beat_id},
    )
    take_workflow = Workflow(
        workflow_id=f"{beat.beat_id}__take_0",
        steps=[video_step],
        global_provenance={"shot_id": beat.beat_id},
    )
    failed_take = beat.new_take(workflow=take_workflow)
    failed_take.status = "failed"
    beat.primary_take_id = failed_take.take_id


# ── selector parsing ────────────────────────────────────────────────


def test_parse_continuity_maps_to_batch():
    """A CONT selector resolves to the continuity strategy and a BATCH_NNN scene id."""
    parsed = parse_batch_selector("EP001_CONT_004")
    assert parsed is not None
    assert (parsed.strategy, parsed.ordinal, parsed.scene_id) == (
        "continuity",
        4,
        "BATCH_004",
    )


def test_parse_oner_maps_to_oner():
    """A ONER selector resolves to the oner strategy and an ONER_NNN scene id."""
    parsed = parse_batch_selector("EP002_ONER_003")
    assert parsed is not None
    assert (parsed.strategy, parsed.ordinal, parsed.scene_id) == (
        "oner",
        3,
        "ONER_003",
    )


def test_parse_invalid_returns_none():
    """Malformed selectors parse to ``None`` rather than raising."""
    rejected = (
        "garbage",
        "EP1_CONT_4",       # not zero-padded
        "EP001_SOLO_004",   # unknown strategy
        "",
    )
    for selector in rejected:
        assert parse_batch_selector(selector) is None, selector


# ── run_generation --batch resolution ───────────────────────────────


def test_batch_invalid_selector_rejected():
    """run_generation reports an unparseable --batch value as invalid_batch_selector."""
    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="garbage",
        force_new_take=True,
        dry_run=True,
    )
    assert outcome["error"] == "invalid_batch_selector"


def test_batch_dry_run_resolves_continuity():
    """A saved BATCH_004 scene is resolved from its CONT selector in a dry run."""
    _save_batch_scene()
    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_CONT_004",
        force_new_take=True,
        dry_run=True,
    )
    assert outcome["success"] is True
    assert outcome["dry_run"] is True
    assert outcome["grouping"] == "continuity"
    assert "estimated_cost_usd" in outcome


def test_batch_dry_run_resolves_oner():
    """A saved ONER_002 scene is resolved from its ONER selector in a dry run."""
    _save_batch_scene(scene_id="ONER_002", strategy="oner", ordinal=2)
    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_ONER_002",
        force_new_take=True,
        dry_run=True,
    )
    assert outcome["success"] is True
    assert outcome["grouping"] == "oner"


def test_batch_scene_missing():
    """A well-formed selector with no persisted scene yields batch_scene_missing."""
    # Nothing is saved for ordinal 009 in the isolated projects root.
    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_CONT_009",
        force_new_take=True,
        dry_run=True,
    )
    assert outcome["error"] == "batch_scene_missing"


def test_batch_metadata_mismatch_raises():
    """Disagreement between persisted grouping and the selector is reported as an error."""
    # The scene file is BATCH_004, but its stored grouping ordinal says 5.
    _save_batch_scene(ordinal=5)
    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_CONT_004",
        force_new_take=True,
        dry_run=True,
    )
    assert outcome["error"] == "batch_selector_metadata_mismatch"


def test_legacy_batch_files_ignored():
    """A corrupt legacy ``1_BATCH_004.json`` next to the real scene is never read.

    Resolution is expected to go through ``scene_path()`` (i.e. the
    ``ep_001_BATCH_004.json`` file), so the unparseable legacy file must not
    affect the outcome.
    """
    _save_batch_scene()
    legacy_file = (
        CoreProjectPaths.for_project("fixture").orchestration_scenes_dir
        / "1_BATCH_004.json"
    )
    legacy_file.write_text("{ not valid json at all", encoding="utf-8")

    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_CONT_004",
        force_new_take=True,
        dry_run=True,
    )
    assert outcome["success"] is True
    assert outcome["grouping"] == "continuity"


# ── flag interplay ──────────────────────────────────────────────────


def test_unknown_strategy_rejected():
    """An unrecognised --strategy value is reported as unknown_author_strategy."""
    _save_batch_scene()
    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_CONT_004",
        force_new_take=True,
        dry_run=True,
        strategy="not_a_strategy",
    )
    assert outcome["error"] == "unknown_author_strategy"


def test_strategy_requires_new_take():
    """--strategy without --new-take is rejected with flag_requires_new_take.

    No scene is saved here, so this check must fire before batch resolution
    (which would otherwise report batch_scene_missing).
    """
    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_CONT_004",
        force_new_take=False,
        strategy="shot_spec",
    )
    assert outcome["error"] == "flag_requires_new_take"


def test_batch_mutual_exclusion_with_pass(monkeypatch):
    """Passing both --batch and --pass makes argparse exit with status 2."""
    # --batch and --pass share the argparse mutually-exclusive scope group.
    cli_args = [
        "generate.py",
        "--project", "fixture",
        "--episode", "1",
        "--new-take",
        "--batch", "EP001_CONT_004",
        "--pass", "PASS_011",
    ]
    monkeypatch.setattr(sys, "argv", cli_args)
    with pytest.raises(SystemExit) as excinfo:
        generate.main()
    assert excinfo.value.code == 2


# ── live threading ──────────────────────────────────────────────────


def test_batch_strategy_override_reaches_run_scene(monkeypatch):
    """--strategy on a --batch new-take run is forwarded to EpisodeRunner.run_scene.

    run_scene is replaced with a recorder, so nothing is dispatched. The test
    checks the keyword arguments the CLI passes for the BATCH_004 reroll.
    """
    _save_batch_scene()
    recorded: dict = {}

    async def _fake_run_scene(self, scene, **kwargs):
        recorded.update(kwargs)
        return scene

    monkeypatch.setattr(EpisodeRunner, "run_scene", _fake_run_scene)
    # run_scene is faked, so the StepRunner is never used for dispatch.
    monkeypatch.setattr(generate, "StepRunner", lambda store, paths, episode=None: MagicMock())

    result = generate.run_generation(
        project="fixture", episode=1, batch="EP001_CONT_004",
        force_new_take=True, strategy="shot_spec",
    )

    # If resolution/validation rejected the run before run_scene was reached,
    # fail with the CLI's result payload instead of a bare KeyError below.
    assert recorded, f"run_scene was not invoked with kwargs; run_generation returned {result!r}"
    assert recorded["strategy_override"] == "shot_spec"
    assert recorded["force_new_take"] is True
    assert recorded["reroll_beat_id"] == "BATCH_004"
    assert recorded["reroll_note"] is None


def test_batch_new_take_returns_structured_board_gate_block_for_not_derived(
    monkeypatch, capsys
):
    """main() emits a JSON board_gate_blocked payload when the active version is not_derived.

    BATCH_004 is registered as a candidate and conformed to v2, then the CLI is
    run with --new-take --batch; stdout is parsed as a single JSON document.
    """
    saved = _save_batch_scene()
    versions = SceneVersionStore("fixture", "ep_001")
    versions.write_scene_candidate("BATCH_004", saved)
    versions.conform("BATCH_004", 2)
    monkeypatch.setattr(generate, "StepRunner", lambda store, paths, episode=None: MagicMock())

    cli_args = [
        "generate.py",
        "--project", "fixture",
        "--episode", "1",
        "--new-take",
        "--batch", "EP001_CONT_004",
    ]
    monkeypatch.setattr(sys, "argv", cli_args)

    exit_code = generate.main()
    payload = json.loads(capsys.readouterr().out)

    assert exit_code != 0
    assert payload["success"] is False
    assert payload["error"] == "board_gate_blocked"
    assert payload["reason"] == "active_version_not_derived"
    assert payload["beat_id"] == "BATCH_004"
    assert "active scene version v2" in payload["message"]


def test_batch_new_take_returns_structured_scene_version_conflict(monkeypatch):
    """A SceneVersionConflictError while saving surfaces as a structured error dict.

    The batch beat gets a failed primary take, then the runner's
    ``save_active_scene`` is patched to raise a v2→v3 conflict.
    """
    saved = _save_batch_scene()
    _add_failed_primary(saved.beats[0])
    save_scene(saved, scene_path("fixture", "ep_001", "BATCH_004"))
    monkeypatch.setattr(generate, "StepRunner", lambda store, paths, episode=None: MagicMock())

    def _raise_conflict(*args, **kwargs):
        raise SceneVersionConflictError("BATCH_004", 2, 3)

    monkeypatch.setattr(
        "recoil.pipeline.orchestrator.episode_runner.save_active_scene",
        _raise_conflict,
    )

    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_CONT_004",
        force_new_take=True,
    )

    assert outcome["success"] is False
    expected_fields = {
        "error": "scene_version_conflict",
        "batch_id": "BATCH_004",
        "expected_version": 2,
        "current_version": 3,
    }
    for field, expected in expected_fields.items():
        assert outcome[field] == expected


def _save_video_i2v_scene(
    project="fixture", episode_token="ep_001", scene_id="BATCH_007"
):
    """Persist a per-shot solo batch scene (modality video_i2v, not boardable).

    The single beat is keyed by its shot id ``EP001_SH40`` rather than by the
    scene id, and carries no batch grouping metadata.
    """
    solo_shot = _shot("EP001_SH40", 1)
    solo_beat = Beat(
        beat_id="EP001_SH40",
        max_takes=5,
        beat_metadata=dict(
            scene_id=scene_id,
            modality="video_i2v",
            shot=dataclasses.asdict(solo_shot),
            inputs_fingerprint="",
        ),
    )
    solo_scene = Scene(
        scene_id=scene_id,
        beats=[solo_beat],
        scene_metadata=dict(episode=episode_token, project=project, batch=True),
    )
    save_scene(solo_scene, scene_path(project, episode_token, scene_id))
    return solo_scene


def test_freshness_gate_blocks_boardable_but_skips_nonboardable_video_i2v():
    """REC-231: the not_derived freshness block applies only to boardable beats.

    Both scenes below end up with a conformed, not_derived active version (v2).
    The r2v_multi batch (boardable) must be blocked until it is re-boarded. The
    video_i2v solo batch has no r2v_multi beat, so mark_derived can never clear
    it; blocking it would leave it permanently undispatchable, so the gate must
    let it through.
    """
    from recoil.pipeline.orchestrator.episode_runner import (
        BoardGateError,
        _preflight_scene_version_freshness,
    )

    versions = SceneVersionStore("fixture", "ep_001")

    boardable_scene = _save_batch_scene(scene_id="BATCH_004")
    versions.write_scene_candidate("BATCH_004", boardable_scene)
    versions.conform("BATCH_004", 2)  # active v2 is not_derived

    solo_scene = _save_video_i2v_scene(scene_id="BATCH_007")
    versions.write_scene_candidate("BATCH_007", solo_scene)
    versions.conform("BATCH_007", 2)  # active v2 is not_derived

    def _probe_beat(beat_id, scene_id, modality):
        return Beat(
            beat_id=beat_id,
            max_takes=5,
            beat_metadata={"scene_id": scene_id, "modality": modality},
        )

    # boardable + not_derived → BLOCKED
    with pytest.raises(BoardGateError):
        _preflight_scene_version_freshness(
            "fixture", 1, [_probe_beat("BATCH_004", "BATCH_004", "r2v_multi")]
        )

    # non-boardable + not_derived → SKIPPED: returning without raising is the pass condition
    _preflight_scene_version_freshness(
        "fixture", 1, [_probe_beat("EP001_SH40", "BATCH_007", "video_i2v")]
    )


def test_noncanonical_selector_rejected():
    """Merge-gate r3: EP0001_CONT_0004 (4-digit fields) must be invalid.

    A 4-digit episode field, a 4-digit ordinal field, or both are all
    rejected; each failing case is reported via the assertion message.
    """
    # parse_batch_selector is already imported at module level.
    for bad in ("EP0001_CONT_0004", "EP001_CONT_0004", "EP0001_CONT_004"):
        assert parse_batch_selector(bad) is None, bad


def test_retry_with_batch_still_deferred():
    """Merge-gate r4: --retry must NEVER ride the --batch reroll path.

    Uses the module-level ``generate`` import; no fixtures are required.
    """
    result = generate.run_generation(
        project="fixture", episode=1, retry=True, force_new_take=True,
        batch="EP001_CONT_004", grouping="continuity",
    )
    # .get() so an unexpected success payload fails with the dict shown,
    # rather than with an opaque KeyError.
    assert result.get("error") == "reroll_non_coverage_deferred", result


def test_retry_with_batch_deferred_even_with_coverage_grouping():
    """Merge-gate r10: --retry + --batch is incompatible under ANY grouping.

    Exercises coverage, auto and continuity groupings; each must be deferred.
    """
    for grouping in ("coverage", "auto", "continuity"):
        result = generate.run_generation(
            project="fixture", episode=1, retry=True, force_new_take=True,
            batch="EP001_CONT_004", grouping=grouping,
        )
        # .get() so an unexpected success payload fails with the grouping and
        # dict shown, rather than with an opaque KeyError.
        assert result.get("error") == "reroll_non_coverage_deferred", (grouping, result)


def test_cli_failed_take_reports_dispatch_failed(monkeypatch):
    """Merge-gate r12: CLI must not emit a dispatched record for a failed take.

    run_scene is faked to append one take whose status is ``"failed"``; the
    result must be a dispatch_failed error with an empty ``dispatched`` list.
    """
    from types import SimpleNamespace

    _save_batch_scene()

    async def _run_scene_with_failed_take(self, scene, **kwargs):
        target_beat = scene.beats[0]
        failed_take = SimpleNamespace(
            take_index=len(target_beat.takes), status="failed"
        )
        target_beat.takes.append(failed_take)
        return scene

    monkeypatch.setattr(EpisodeRunner, "run_scene", _run_scene_with_failed_take)
    monkeypatch.setattr(generate, "StepRunner", lambda store, paths, episode=None: MagicMock())

    outcome = generate.run_generation(
        project="fixture",
        episode=1,
        batch="EP001_CONT_004",
        force_new_take=True,
    )
    assert outcome["success"] is False
    assert outcome["error"] == "dispatch_failed"
    assert outcome["dispatched"] == []


def test_incompatible_modality_strategy_rejected():
    """Merge-gate r13: start_end_frame (video_i2v-only) must be rejected pre-mutation.

    No scene is saved first, so receiving unknown_author_strategy (rather than
    batch_scene_missing) also shows the strategy check runs before batch
    resolution.
    """
    result = generate.run_generation(
        project="fixture", episode=1, batch="EP001_CONT_004",
        force_new_take=True, strategy="start_end_frame",
    )
    assert result.get("error") == "unknown_author_strategy", result
