from __future__ import annotations

import json
import importlib
import sys
import types
from types import SimpleNamespace

import pytest

from recoil.core.paths import PIPELINE_ROOT


# Import ``api.routes.generation`` from PIPELINE_ROOT without leaving the
# interpreter's import state changed: ``sys.path`` and every ``api`` /
# ``api.*`` entry in ``sys.modules`` are snapshotted here and restored in the
# ``finally`` below. Only the ``generation`` module object survives, held by
# the module-level name bound inside the ``try``.
_previous_sys_path = list(sys.path)
_previous_api_modules = {
    m: sys.modules[m] for m in list(sys.modules) if m == "api" or m.startswith("api.")
}
try:
    _pipeline_root = str(PIPELINE_ROOT)
    # Move PIPELINE_ROOT to the front so ``api`` resolves to the pipeline's
    # package rather than any other ``api`` earlier on the path.
    if _pipeline_root in sys.path:
        sys.path.remove(_pipeline_root)
    sys.path.insert(0, _pipeline_root)
    # Evict cached ``api`` modules so the import below is loaded fresh from
    # PIPELINE_ROOT instead of reusing whatever was imported earlier.
    for _stale in [m for m in sys.modules if m == "api" or m.startswith("api.")]:
        del sys.modules[_stale]

    import api  # noqa: E402

    # Register a bare module object with a ``__path__`` as ``api.routes``.
    # The import system then finds submodules in the routes directory without
    # executing api/routes/__init__.py (presumably to avoid importing every
    # route module and its dependencies — TODO confirm).
    _routes_pkg = types.ModuleType("api.routes")
    _routes_pkg.__path__ = [str(PIPELINE_ROOT / "api" / "routes")]
    sys.modules["api.routes"] = _routes_pkg
    setattr(api, "routes", _routes_pkg)
    generation = importlib.import_module("api.routes.generation")
finally:
    # Restore the original path and ``api`` module cache, even if the import
    # above failed.
    # NOTE(review): any ``api.*`` import that ``generation`` performs lazily
    # (at call time) after this point resolves against the restored
    # sys.path, not PIPELINE_ROOT.
    sys.path[:] = _previous_sys_path
    for _stale in [m for m in sys.modules if m == "api" or m.startswith("api.")]:
        del sys.modules[_stale]
    sys.modules.update(_previous_api_modules)


# The only shot used by this module; the _Store fake asserts every lookup
# and update targets it.
SHOT_ID = "EP001_SH001"
# Exact error text expected when no prompt can be resolved for SHOT_ID.
NO_PROMPT_ERROR = f"no prompt resolvable for shot {SHOT_ID} — refusing paid dispatch"


class _Store:
    """In-memory stand-in for the shot store passed to the generation route.

    Holds a single shot (``SHOT_ID``) with status ``keyframe_approved`` and a
    ``hero_frame`` gate result. Every ``update_shot`` call is logged in
    ``updates`` and merged into the shot.
    """

    def __init__(self):
        self.shot = dict(
            shot_id=SHOT_ID,
            status="keyframe_approved",
            gate_results={"hero_frame": "frames/hero.png"},
        )
        # One kwargs dict per update_shot call, in call order.
        self.updates = []

    def get_shot(self, shot_id):
        """Return the live shot dict; only ``SHOT_ID`` is accepted."""
        assert shot_id == SHOT_ID
        return self.shot

    def update_shot(self, shot_id, **kwargs):
        """Log *kwargs*, merge them into the shot, and return the shot."""
        assert shot_id == SHOT_ID
        self.updates.append(kwargs)
        self.shot.update(kwargs)
        return self.shot


class _GenerationTracker:
    def __init__(self):
        self.active = set()

    def try_start(self, shot_id):
        self.active.add(shot_id)
        return True

    def finish(self, shot_id):
        self.active.discard(shot_id)


class _AvailableClient:
    def is_available(self):
        return True


@pytest.fixture
def route_env(tmp_path, monkeypatch):
    """Build an isolated environment for calling ``generation.generate_video``.

    Creates ``<tmp>/project/frames/hero.png`` (the hero frame referenced by
    ``_Store``) and an empty ``<tmp>/plans`` directory. It then patches the
    route's collaborators:

    * ``submit_task`` runs the task function synchronously and records its
      outcome in ``task_records``, keyed by task id;
    * ``gen_tracker`` always allows a start;
    * ``get_project_aspect_ratio`` returns ``"16:9"``;
    * ``FalAiKlingClient`` reports itself available, ``StepRunner`` builds an
      inert object, and ``ProjectPaths.for_episode`` returns a minimal
      namespace.

    Returns:
        A namespace with ``paths`` (``project_dir`` / ``plans_dir``),
        ``store``, ``task_records``, and ``plan_path``. ``plan_path`` is where
        tests may write the episode plan; presumably it is the file the route
        reads for this shot — TODO confirm.
    """
    project_dir = tmp_path / "project"
    frame_path = project_dir / "frames" / "hero.png"
    frame_path.parent.mkdir(parents=True)
    frame_path.write_bytes(b"frame")
    plans_dir = tmp_path / "plans"
    plans_dir.mkdir()

    store = _Store()
    task_records = {}

    def submit_immediately(entity_id, action, fn, *args, metadata=None, **kwargs):
        # Sequential ids ("task-1", "task-2", ...) so repeated submissions in
        # one test keep separate records instead of overwriting each other.
        task_id = f"task-{len(task_records) + 1}"
        task_records[task_id] = {
            "task_id": task_id,
            "entity_id": entity_id,
            "action": action,
            "status": "running",
            "result": None,
            "error": None,
            "metadata": metadata or {},
        }
        # Mirror a task runner: an ordinary exception becomes a failed task.
        # BaseException subclasses (e.g. pytest outcome exceptions) still
        # propagate to the test.
        try:
            result = fn(*args, **kwargs)
        except Exception as exc:
            task_records[task_id]["status"] = "failed"
            task_records[task_id]["error"] = str(exc)
        else:
            task_records[task_id]["status"] = "complete"
            task_records[task_id]["result"] = result
        return task_id

    monkeypatch.setattr(generation, "submit_task", submit_immediately)
    monkeypatch.setattr(generation, "gen_tracker", _GenerationTracker())
    monkeypatch.setattr(
        generation, "get_project_aspect_ratio", lambda project: "16:9"
    )

    import recoil.execution.api_client as api_client
    import recoil.execution.step_runner as step_runner
    import recoil.execution.step_types as step_types

    monkeypatch.setattr(api_client, "FalAiKlingClient", _AvailableClient)
    # Accept any call shape so the stub does not depend on whether the route
    # passes StepRunner arguments positionally or by keyword.
    monkeypatch.setattr(step_runner, "StepRunner", lambda *args, **kwargs: object())
    monkeypatch.setattr(
        step_types.ProjectPaths,
        "for_episode",
        classmethod(lambda cls, project, episode: SimpleNamespace(project=project)),
    )

    return SimpleNamespace(
        paths={"project_dir": project_dir, "plans_dir": plans_dir},
        store=store,
        task_records=task_records,
        plan_path=plans_dir / "ep_001_plan.json",
    )


def _invoke_generate_video(route_env):
    """Call ``generate_video`` for ``SHOT_ID`` with an empty prompt.

    Asserts the route accepted the request with HTTP 202. The synchronous
    ``submit_task`` fake has already run the task by the time the route
    returns, so the record is final.

    Returns:
        The task record stored in ``route_env.task_records`` under the
        ``task_id`` from the route's JSON response.
    """
    response = generation.generate_video(
        body={"shot_id": SHOT_ID, "prompt": ""},
        project="fixture",
        paths=route_env.paths,
        store=route_env.store,
    )
    # Include the response payload so a rejected request shows why it failed.
    assert response.status_code == 202, response.body
    payload = json.loads(response.body)
    return route_env.task_records[payload["task_id"]]


@pytest.mark.parametrize(
    "plan_text",
    [None, "{not-json"],
    ids=["missing-plan", "malformed-plan"],
)
def test_generate_video_no_prompt_fails_before_paid_dispatch(
    route_env, monkeypatch, plan_text
):
    """A shot with no resolvable prompt must fail without any paid dispatch.

    The request sends an empty prompt. The plan is either absent (``None``)
    or not valid JSON, so nothing can supply a prompt. The task must fail
    with ``NO_PROMPT_ERROR``, the shot must be marked ``video_failed``, and
    ``dispatch`` must never be called.
    """
    if plan_text is not None:
        route_env.plan_path.write_text(plan_text, encoding="utf-8")

    dispatch_mod = importlib.import_module("recoil.pipeline.core.dispatch")

    # Record calls rather than raising pytest.fail() inside the stub. An
    # exception raised deep inside the route/task machinery could be
    # intercepted there and hide the real cause.
    dispatch_calls = []
    monkeypatch.setattr(
        dispatch_mod,
        "dispatch",
        lambda *args, **kwargs: dispatch_calls.append((args, kwargs)),
    )

    task = _invoke_generate_video(route_env)

    # Check this first: if dispatch ran, the later assertions would fail
    # with a less informative message.
    assert dispatch_calls == [], "dispatch should not be called"
    assert task["status"] == "failed"
    assert task["error"] == NO_PROMPT_ERROR
    assert {
        "status": "video_failed",
        "error_message": NO_PROMPT_ERROR,
    } in route_env.store.updates


def test_generate_video_uses_resolvable_plan_prompt(route_env, monkeypatch):
    """A prompt resolvable from the episode plan is used for the paid dispatch.

    The request sends an empty prompt. The prompt reaching ``dispatch`` must
    therefore be the shot's ``prompt_skeleton.action_line`` from the plan,
    and the task must complete.
    """
    plan_prompt = "Jinx runs through the corridor as sparks fall."
    shot_entry = {
        "shot_id": SHOT_ID,
        "prompt_data": {"prompt_skeleton": {"action_line": plan_prompt}},
    }
    route_env.plan_path.write_text(
        json.dumps({"shots": [shot_entry]}), encoding="utf-8"
    )

    dispatch_calls = []

    def dispatch_success(kind, payload, *, context):
        # Record the call and report a successful paid run.
        dispatch_calls.append((kind, payload, context))
        run_result = SimpleNamespace(success=True, error=None, cost_usd=1.23)
        return SimpleNamespace(run_result=run_result)

    monkeypatch.setattr(
        importlib.import_module("recoil.pipeline.core.dispatch"),
        "dispatch",
        dispatch_success,
    )

    task = _invoke_generate_video(route_env)

    assert task["status"] == "complete"
    assert task["result"] == {"shot_id": SHOT_ID}
    assert len(dispatch_calls) == 1
    [(kind, payload, _context)] = dispatch_calls
    assert kind == "video_i2v"
    assert payload["prompt"] == plan_prompt
