from __future__ import annotations

import base64
import re
from pathlib import Path
from typing import Any

import pytest

from recoil.pipeline._lib import dispatch_payload as dp
from recoil.pipeline._lib.author_strategies import resolve_strategy
from recoil.pipeline._lib.bible_loader import get_optimal_word_range
from recoil.pipeline._lib.plan_loader import CanonicalShot, CharacterEntry
from recoil.pipeline._lib.prose_validator import Severity, verify_authored_prose
from recoil.pipeline._lib.shot_primitive import ShotPrimitive


# Committed fixture assets (reference PNGs, start/end frames) used by these tests.
FIXTURE_DIR = Path(__file__).with_name("fixtures") / "rec72"
# The 8-byte PNG file signature; fixtures are sanity-checked against it before
# the code under test ever sees them.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# Matches a "[m:ss-m:ss]" beat timecode, capturing the start (sm, ss) and end
# (em, es) minutes/seconds. Minutes may have any number of digits; seconds are
# exactly two digits, so fractional seconds are not recognised.
TIMECODE_RE = re.compile(
    r"\[(?P<sm>\d+):(?P<ss>\d{2})\s*-\s*(?P<em>\d+):(?P<es>\d{2})\]"
)
# Matches an "@ImageN" reference token, capturing the integer N.
IMAGEN_RE = re.compile(r"@Image(\d+)")

# Character config embedded in the r2v shots' raw["refs"]; it is also the list
# of names that must not appear verbatim in the final r2v prompt.
CHARACTER_CONFIG = {
    "JADE": {
        "display_name": "Jade",
        "aliases": ["Jade Cooper"],
        "role": "protagonist",
    },
    "WREN": {
        "display_name": "Wren",
        "aliases": ["Wren Vale"],
        "role": "support",
    },
}
# Location config embedded in the r2v shots' raw["refs"].
LOCATION_CONFIG = {
    "int_lower_decks_maintenance_shaft": {
        "display_name": "Lower Decks Maintenance Shaft",
    },
}

# Canned author_pass output for the r2v_multi test: two timed 4-second beats
# ([0:00-0:04], [0:04-0:08]) that name both characters and the location in
# plain text.
DIRECTED_PROSE_FIXTURE = (
    "[0:00-0:04] The handheld lens pushes low beside Wren's brushed-steel "
    "hand as he clamps the anchor cable, shoulder servos ticking while Jade "
    "holds behind him in Lower Decks Maintenance Shaft, eyes narrowed and "
    "breath measured against the wind.\n"
    "[0:04-0:08] A rack focus pulls from the cable bite to Jade as she steps "
    "past Wren, jaw tight and fingers white on the tether, forcing the move "
    "left-to-right while Wren turns his plated body to shield her."
)

# Canned author_pass output for the start/end-frame i2v test: a single untimed
# brief containing no "@Image" tokens.
START_END_BRIEF_FIXTURE = (
    "Hold the start frame composition as the sealed pod hisses open under a "
    "slow push-in; vapor spills from the seam, status lights pulse colder, "
    "and the hatch rolls back into the final open-pod frame."
)


@pytest.fixture(autouse=True)
def _patch_project_config(monkeypatch: pytest.MonkeyPatch):
    """Keep every test in this module independent of on-disk project config.

    ``dp._project_config_cache`` is emptied both before and after the test, and
    ``dp.load_project_config`` is replaced by a stub that always yields an
    empty mapping for whatever project name it is given.
    """

    def _empty_project_config(_project):
        return {}

    dp._project_config_cache.clear()
    monkeypatch.setattr(dp, "load_project_config", _empty_project_config)
    yield
    # Teardown: drop anything the test put into the cache.
    dp._project_config_cache.clear()


def test_directed_prose_acceptance_on_tartarus_r2v_multi(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Acceptance check for a two-shot seeddance-2.0 ``r2v_multi`` batch.

    ``author_pass`` is stubbed to return ``DIRECTED_PROSE_FIXTURE`` and
    ``_collect_reference_images`` to return the two committed EP002 reference
    PNGs. The test then asserts that ``build_unified_payload`` selected the
    ``directed_prose`` strategy, forwarded the reference images, kept the
    configured character names out of the final prompt, preserved the location
    name, and produced beats, timecodes and ``@ImageN`` tokens consistent with
    the primitive and the reference manifest.
    """
    manifest = {"identity_1": 1, "identity_2": 2}
    ref_images = [
        str(FIXTURE_DIR / "ep002_jade_ref.png"),
        str(FIXTURE_DIR / "ep002_wren_ref.png"),
    ]
    # Fail fast on missing/corrupt fixture assets rather than inside dp.
    for ref in ref_images:
        assert Path(ref).read_bytes().startswith(PNG_SIGNATURE)

    # Hand out fresh copies on every call. If the stub returned ``ref_images``
    # and ``manifest`` themselves, any in-place mutation by the code under test
    # would also rewrite the expected values, and the equality/subset
    # assertions below would pass vacuously.
    monkeypatch.setattr(
        dp,
        "_collect_reference_images",
        lambda *args, **kwargs: (list(ref_images), dict(manifest)),
    )
    captured: dict[str, Any] = {}

    def fake_author_pass(primitive, strategy, **_kwargs):
        # Record what the dispatcher hands to the author so it can be inspected.
        captured["primitive"] = primitive
        captured["strategy"] = strategy
        return DIRECTED_PROSE_FIXTURE

    monkeypatch.setattr(dp, "author_pass", fake_author_pass)
    wren = _r2v_shot(
        "EP002_SH01",
        "WREN",
        shot_type="ECU",
        action_line="Wren locks his mechanical hand on the anchor cable.",
    )
    jade = _r2v_shot(
        "EP002_SH02",
        "JADE",
        shot_type="MS",
        action_line="Jade steps past Wren onto the cable.",
    )

    payload = dp.build_unified_payload(
        dp.PayloadContext(
            project="tartarus",
            modality="r2v_multi",
            shot_id="EP002_PASS_009",
            shot=wren,
            batch_shots=[wren, jade],
            model_id="seeddance-2.0",
            bible={},
        )
    )

    assert "primitive" in captured, "build_unified_payload never called author_pass"
    primitive = captured["primitive"]
    strategy = captured["strategy"]
    prompt = payload["prompt"]

    assert strategy.name == "directed_prose"
    assert payload["provider_hints"]["r2v_multi"] is True
    assert payload["reference_images"] == ref_images
    # DIRECTED_PROSE_FIXTURE names Jade and Wren outright, so this only holds
    # if the payload builder removed or replaced those names in the prompt.
    assert _configured_character_name_leaks(prompt, CHARACTER_CONFIG) == set()
    assert "Lower Decks Maintenance Shaft" in prompt
    assert _beat_count(prompt) == len(primitive.timing_segments)
    assert _timecode_duration(prompt) == pytest.approx(
        primitive.target_editorial_duration_s,
        abs=0.5,
    )
    image_numbers = _image_numbers(prompt)
    assert image_numbers, "prompt contains no @ImageN references"
    assert image_numbers <= set(manifest.values())


def test_start_end_frame_acceptance_on_ep001_sh11_pod_hisses_open(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Acceptance check for the EP001_SH11 kling-v3 start/end-frame i2v shot.

    ``resolve_strategy`` is asserted directly on a standalone primitive. Then
    ``build_unified_payload`` is run with ``author_pass`` stubbed to return
    ``START_END_BRIEF_FIXTURE``. The payload must use ``start_end_frame``,
    carry the start frame path and the end frame as base64 ``image_tail``,
    contain no ``@Image`` tokens, fit the model's i2v word range, and produce
    prose that ``verify_authored_prose`` does not BLOCK.
    """
    start_png = FIXTURE_DIR / "ep001_sh11_start.png"
    end_png = FIXTURE_DIR / "ep001_sh11_end.png"
    for frame in (start_png, end_png):
        assert frame.read_bytes().startswith(PNG_SIGNATURE)

    seen: dict[str, Any] = {}

    def stub_author_pass(primitive, strategy, **_kwargs):
        seen.update(primitive=primitive, strategy=strategy)
        return START_END_BRIEF_FIXTURE

    monkeypatch.setattr(dp, "author_pass", stub_author_pass)
    monkeypatch.setattr(dp, "_collect_reference_images", lambda *a, **k: ([], {}))
    shot = _i2v_shot(start_png, end_png)

    frame_refs = {"start_frame": str(start_png), "end_frame": str(end_png)}
    standalone = ShotPrimitive(
        shot_id="EP001_SH11",
        scene_index=11,
        shot_type="CU",
        target_editorial_duration_s=5.0,
        intent="pod_hisses_open: the sealed cryo pod transforms into the open final frame.",
        refs=frame_refs,
    )
    resolved_modality, resolved_strategy = resolve_strategy(
        standalone,
        model_id="kling-v3",
        requested_modality="video_i2v",
    )

    context = dp.PayloadContext(
        project="tartarus",
        modality="video_i2v",
        shot_id="EP001_SH11",
        start_frame_path=start_png,
        end_frame_path=end_png,
        shot=shot,
        model_id="kling-v3",
        bible={},
    )
    payload = dp.build_unified_payload(context)
    prompt = payload["prompt"]

    assert resolved_strategy.name == "start_end_frame"
    assert resolved_modality == "video_i2v"
    assert seen["strategy"].name == "start_end_frame"
    assert seen["primitive"].refs["start_frame"] == str(start_png)
    assert seen["primitive"].refs["end_frame"] == str(end_png)
    assert "@Image" not in prompt
    assert payload["start_frame"] == str(start_png)
    tail_b64 = payload["image_tail"]
    assert tail_b64
    assert base64.b64decode(tail_b64) == end_png.read_bytes()
    lo, hi = get_optimal_word_range(payload["model"], mode="i2v")
    assert lo <= _word_count(prompt) <= hi

    findings = verify_authored_prose(
        START_END_BRIEF_FIXTURE,
        seen["primitive"],
        seen["strategy"],
    )
    blocking = [finding for finding in findings if finding.severity == Severity.BLOCK]
    assert not blocking


def _r2v_shot(
    shot_id: str,
    char_id: str,
    *,
    shot_type: str,
    action_line: str,
) -> CanonicalShot:
    """Build a one-character seeddance-2.0 ``CanonicalShot`` for the r2v tests.

    Every shot is a 4-second, scene-2 shot in the lower-decks maintenance
    shaft, with ``CHARACTER_CONFIG`` and ``LOCATION_CONFIG`` embedded in its
    raw ``refs``. *action_line* doubles as the raw ``source_text``.
    """
    scene_index = 2
    duration_s = 4.0
    location_id = "int_lower_decks_maintenance_shaft"
    wardrobe = "base"

    skeleton = {
        "action_line": action_line,
        "emotion_line": "Controlled urgency under pressure.",
    }
    raw = {
        "shot_id": shot_id,
        "scene_index": scene_index,
        "duration_s": duration_s,
        "shot_type": shot_type,
        "camera_side": "A",
        "screen_direction": "left-to-right",
        "source_text": action_line,
        "asset_data": {
            "characters": [{"char_id": char_id, "wardrobe_phase_id": wardrobe}],
            "location_id": location_id,
        },
        "prompt_data": {"shot_type": shot_type, "prompt_skeleton": skeleton},
        "refs": {"characters": CHARACTER_CONFIG, "locations": LOCATION_CONFIG},
    }
    return CanonicalShot(
        shot_id=shot_id,
        scene_index=scene_index,
        sequence_id=None,
        pipeline="video",
        previs_model="gemini-3.1-flash-image-preview",
        video_model="seeddance-2.0",
        location_id=location_id,
        characters=[CharacterEntry(char_id=char_id, wardrobe_phase_id=wardrobe)],
        shot_type=shot_type,
        duration_s=duration_s,
        is_env_only=False,
        has_dialogue=False,
        aspect_ratio="9:16",
        raw=raw,
    )


def _i2v_shot(start: Path, end: Path) -> CanonicalShot:
    """Build the character-free EP001_SH11 kling-v3 ``CanonicalShot``.

    The shot is a 5-second CU with ``is_env_only=True``. The *start* and *end*
    frame paths are stored as strings in its raw plan data.
    """
    shot_id = "EP001_SH11"
    scene_index = 11
    duration_s = 5.0
    shot_type = "CU"
    location_id = "int_lower_decks_maintenance_shaft"

    skeleton = {
        "action_line": "The pod hisses open between the committed start and end frames.",
        "emotion_line": "Mechanical pressure releases without overexplaining the in-between.",
    }
    raw = {
        "shot_id": shot_id,
        "scene_index": scene_index,
        "duration_s": duration_s,
        "shot_type": shot_type,
        "start_frame": str(start),
        "end_frame": str(end),
        "source_text": "pod_hisses_open: the cryo pod hisses open from sealed start to open end.",
        "asset_data": {"characters": [], "location_id": location_id},
        "prompt_data": {"shot_type": shot_type, "prompt_skeleton": skeleton},
    }
    return CanonicalShot(
        shot_id=shot_id,
        scene_index=scene_index,
        sequence_id=None,
        pipeline="video",
        previs_model="gemini-3.1-flash-image-preview",
        video_model="kling-v3",
        location_id=location_id,
        characters=[],
        shot_type=shot_type,
        duration_s=duration_s,
        is_env_only=True,
        has_dialogue=False,
        aspect_ratio="9:16",
        raw=raw,
    )


def _beat_count(text: str) -> int:
    """Return how many ``[m:ss-m:ss]`` timecode beats occur in *text*."""
    return sum(1 for _beat in TIMECODE_RE.finditer(text))


def _timecode_duration(text: str) -> float:
    """Return the summed span, in seconds, of every timecode beat in *text*.

    Each beat contributes ``end - start``. Beats are summed independently, so
    gaps between beats are not counted and overlaps are counted twice.
    Returns ``0.0`` when *text* has no beats.
    """

    def _to_seconds(minutes: str, seconds: str) -> int:
        return int(minutes) * 60 + int(seconds)

    return sum(
        (
            _to_seconds(beat["em"], beat["es"]) - _to_seconds(beat["sm"], beat["ss"])
            for beat in TIMECODE_RE.finditer(text)
        ),
        0.0,
    )


def _image_numbers(text: str) -> set[int]:
    """Return the distinct integers N referenced by ``@ImageN`` tokens in *text*."""
    return set(map(int, IMAGEN_RE.findall(text)))


def _configured_character_name_leaks(
    text: str,
    characters: dict[str, dict[str, Any]],
) -> set[str]:
    leaks: set[str] = set()
    for cid, config in characters.items():
        candidates = {cid, cid.title()}
        display = config.get("display_name")
        if display:
            candidates.add(str(display))
        aliases = config.get("aliases") or []
        for alias in aliases:
            candidates.add(str(alias))
        for candidate in candidates:
            if re.search(rf"\b{re.escape(candidate)}\b", text):
                leaks.add(candidate)
    return leaks


def _word_count(text: str) -> int:
    return len(re.findall(r"\b[\w'-]+\b", text))
