from __future__ import annotations

import json
import sys

import pytest

from recoil.pipeline.cli import generate


# Project directory name: created under the temporary projects root by the
# autouse fixture and passed to the CLI via ``--project``.
PROJECT = "fixture"
# Coverage pass id every test targets via ``--pass``; the fake pass stores
# assert that this is the id the CLI looks up.
PASS_ID = "PASS_011"


def _coverage_pass_dict() -> dict:
    """Build a minimal single-segment coverage pass dict for ``PASS_ID``.

    The pass spans exactly one shot (``EP001_SH01``) on camera side ``A`` with
    pass type ``env`` and generation mode ``t2v``. Tests return it from a
    stubbed pass loader in place of real pass data.
    """
    shot_id = "EP001_SH01"
    only_segment = {
        "segment_index": 0,
        "source_shot_id": shot_id,
        "shot_type": "MS",
        "duration_s": 2,
        "prompt": f"shot {shot_id}",
    }
    return {
        "pass_id": PASS_ID,
        "episode_id": "ep_001",
        # Start and end of the range are the same shot: a one-shot pass.
        "shot_range": [shot_id, shot_id],
        "camera_side": "A",
        "label": "fixture",
        "focus_character": "",
        "pass_type": "env",
        "location_id": "L1",
        "generation_config": {"mode": "t2v"},
        "segments": [only_segment],
    }


@pytest.fixture(autouse=True)
def _projects_root(tmp_path, monkeypatch):
    """Provision an isolated projects root for every test in this module.

    Layout created under ``tmp_path``::

        projects/
            .recoil-data-root   (empty marker file)
            <PROJECT>/          (empty project directory)

    ``RECOIL_PROJECTS_ROOT`` is set to ``projects/`` through ``monkeypatch``,
    so the variable is restored when the test finishes. Returns the
    ``projects/`` path.
    """
    projects_dir = tmp_path.joinpath("projects")
    projects_dir.mkdir()
    projects_dir.joinpath(".recoil-data-root").touch()
    projects_dir.joinpath(PROJECT).mkdir()
    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", str(projects_dir))
    return projects_dir


def _run_main(monkeypatch, capsys, *args: str) -> tuple[int, dict]:
    """Run ``generate.main()`` as if invoked from the command line.

    ``sys.argv`` is replaced with the fixed ``--project PROJECT --episode 1``
    prefix followed by *args*. Stdout is then captured and parsed as JSON.

    Returns:
        The exit code returned by ``main()`` and the JSON object it printed
        to stdout.

    Raises:
        AssertionError: if stdout is not valid JSON or is not a JSON object.
            The captured text is included in the message, so a failing test
            shows what ``main()`` actually printed instead of a bare
            ``JSONDecodeError``.
    """
    monkeypatch.setattr(
        sys,
        "argv",
        ["generate.py", "--project", PROJECT, "--episode", "1", *args],
    )
    code = generate.main()
    out = capsys.readouterr().out
    try:
        payload = json.loads(out)
    except json.JSONDecodeError as exc:
        raise AssertionError(
            f"generate.main() did not print valid JSON to stdout: {out!r}"
        ) from exc
    # Callers index the payload like a mapping; fail here with context rather
    # than with a TypeError deep inside a test.
    assert isinstance(payload, dict), (
        f"expected a JSON object on stdout, got {payload!r}"
    )
    return code, payload


def test_generate_retry_absent_record_fails_loud_before_dispatch(
    monkeypatch,
    capsys,
):
    """Plain ``--retry`` on a pass with no stored record is rejected.

    The CLI must exit with ``EXIT_VALIDATION`` and a ``retry_record_missing``
    payload for ``PASS_ID`` whose message mentions ``--force-retry``. It must
    never reach the coverage-pass loader.
    """

    class _AbsentRecordStore:
        # Stand-in for PassStore: the lookup for PASS_ID finds no record.
        def __init__(self, project):
            self.project = project

        def get_pass(self, pass_id):
            assert pass_id == PASS_ID
            return None

    def _refuse_load(*args, **kwargs):
        # Reaching the loader means the retry guard did not stop the run.
        raise AssertionError("generation should not proceed without --force-retry")

    for attr, replacement in (
        ("PassStore", _AbsentRecordStore),
        ("_load_coverage_pass_dicts", _refuse_load),
    ):
        monkeypatch.setattr(generate, attr, replacement)

    cli_args = ("--pass", PASS_ID, "--retry", "--dry-run")
    code, result = _run_main(monkeypatch, capsys, *cli_args)

    assert code == generate.EXIT_VALIDATION
    assert result["success"] is False
    assert result["error"] == "retry_record_missing"
    assert result["pass_id"] == PASS_ID
    assert "--force-retry" in result["message"]


def test_generate_force_retry_absent_record_proceeds(monkeypatch, capsys):
    """``--force-retry`` lets a pass with no stored record run a dry run.

    Checks that the store is constructed exactly once for ``PROJECT``, that
    passes are loaded exactly once for episode 1, and that the CLI reports a
    successful dry run with ``EXIT_OK``.
    """
    stores_constructed_for = []
    loads_requested = []

    class _AbsentRecordStore:
        # Records the project the CLI built the store for; no record exists.
        def __init__(self, project):
            stores_constructed_for.append(project)

        def get_pass(self, pass_id):
            assert pass_id == PASS_ID
            return None

    def _load_fixture_passes(paths, episode):
        # Capture the loader arguments and hand back the single fixture pass.
        loads_requested.append((paths.project, episode))
        return [_coverage_pass_dict()]

    def _no_validation_errors(passes):
        return []

    stubs = {
        "PassStore": _AbsentRecordStore,
        "_load_coverage_pass_dicts": _load_fixture_passes,
        "validate_all_passes": _no_validation_errors,
    }
    for attr, replacement in stubs.items():
        monkeypatch.setattr(generate, attr, replacement)

    code, result = _run_main(
        monkeypatch, capsys, "--pass", PASS_ID, "--force-retry", "--dry-run"
    )

    assert code == generate.EXIT_OK
    assert stores_constructed_for == [PROJECT]
    assert loads_requested == [(PROJECT, 1)]
    assert result["success"] is True
    assert result["dry_run"] is True


@pytest.mark.parametrize(
    "status, error",
    [
        ("completed", "already_completed"),
        ("generating", "orphaned_in_flight"),
    ],
)
def test_generate_force_retry_keeps_terminal_retry_guards(
    monkeypatch,
    capsys,
    status,
    error,
):
    """``--force-retry`` must not override the completed/generating guards.

    For a stored record in *status*, the CLI exits with ``EXIT_VALIDATION``
    and reports *error* for ``PASS_ID`` without loading any coverage passes.
    """

    class _ExistingRecordStore:
        # Returns a stored record for PASS_ID carrying the parametrized status.
        def __init__(self, project):
            self.project = project

        def get_pass(self, pass_id):
            assert pass_id == PASS_ID
            return dict(pass_id=pass_id, status=status)

    def _refuse_load(*args, **kwargs):
        # Reaching the loader means --force-retry bypassed the guard.
        raise AssertionError("--force-retry must not bypass completed/generating")

    monkeypatch.setattr(generate, "PassStore", _ExistingRecordStore)
    monkeypatch.setattr(generate, "_load_coverage_pass_dicts", _refuse_load)

    code, result = _run_main(
        monkeypatch, capsys, "--pass", PASS_ID, "--force-retry", "--dry-run"
    )

    assert code == generate.EXIT_VALIDATION
    assert result["success"] is False
    assert result["error"] == error
    assert result["pass_id"] == PASS_ID
