"""Tests for pipeline/lib/frame_uprez.py (Phase 10).

Unit tests that DO NOT hit the live Gemini API. The two Gemini call sites
(`_run_nbp_uprez`, `_run_nb2_uprez`) and the Flash rubric (`_gemini_rubric_check`)
are monkeypatched. Integration/live-model coverage is the pre-ship A/B harness
(pipeline/tools/empirical/uprez_ab.py).
"""

from __future__ import annotations

import io
import sys
from pathlib import Path

import pytest
from PIL import Image

# Prepend two ancestor directories of this test file to sys.path (skipping any
# that are already present) before the module under test is imported below.
# parents[3]: presumably the directory that lets the `recoil...` import below
# resolve — NOTE(review): confirm against the repository layout.
_RECOIL_ROOT = Path(__file__).resolve().parents[3]
if str(_RECOIL_ROOT) not in sys.path:
    sys.path.insert(0, str(_RECOIL_ROOT))
# pipeline/ also on sys.path so frame_uprez's internal `from lib.frame_editor`
# (project-local convention shared with frame_editor.py) resolves.
_PIPELINE_ROOT = Path(__file__).resolve().parents[2]
if str(_PIPELINE_ROOT) not in sys.path:
    sys.path.insert(0, str(_PIPELINE_ROOT))

from recoil.pipeline._lib import frame_uprez as U  # noqa: E402  # type: ignore


# ── Helpers ─────────────────────────────────────────────────────
def _png_bytes(size=(64, 36), color=(120, 140, 160)) -> bytes:
    """Encode a solid-color RGB image of ``size`` as PNG and return the bytes.

    The default 64x36 size corresponds to a 16:9 frame.
    """
    with io.BytesIO() as sink:
        Image.new("RGB", size, color=color).save(sink, format="PNG")
        return sink.getvalue()


def _perturbed_png(base_bytes: bytes, delta: int = 5) -> bytes:
    """Re-encode ``base_bytes`` as PNG with every RGB channel raised by ``delta``.

    Channel values are capped at 255, so the result is a slightly brighter
    near-copy of the input.
    """
    frame = Image.open(io.BytesIO(base_bytes)).convert("RGB")
    pixels = frame.load()
    width, height = frame.size
    for row in range(height):
        for col in range(width):
            pixels[col, row] = tuple(
                min(255, channel + delta) for channel in pixels[col, row]
            )
    with io.BytesIO() as sink:
        frame.save(sink, format="PNG")
        return sink.getvalue()


@pytest.fixture
def tmp_image(tmp_path):
    """Write the default ``_png_bytes()`` image to ``in.png`` and return its path."""
    image_path = tmp_path / "in.png"
    image_path.write_bytes(_png_bytes())
    return image_path


# ── Polymorphic input dispatch ──────────────────────────────────

def test_image_source_loads_bytes_directly(tmp_image, monkeypatch):
    """Image inputs are loaded as PNG bytes without touching the ffmpeg path."""
    ffmpeg_state = {"ff": False}

    def boom_extract(*args, **kwargs):
        # Flip the flag before raising so the final assert still detects a
        # call even if the exception were swallowed somewhere upstream.
        ffmpeg_state["ff"] = True
        raise AssertionError("ffmpeg must not be invoked for image source")

    monkeypatch.setattr(U, "_extract_frame_from_video", boom_extract)

    loaded = U._load_source(tmp_image, frame_select="first")

    assert loaded.startswith(b"\x89PNG")
    assert ffmpeg_state["ff"] is False


def test_video_source_dispatches_to_ffmpeg(tmp_path, monkeypatch):
    """A .mp4 input is handed to _extract_frame_from_video with its frame_select."""
    clip = tmp_path / "clip.mp4"
    # Placeholder bytes rather than a real mp4; enough to exist as a regular file.
    clip.write_bytes(b"\x00\x00\x00\x18ftypmp42fake")
    seen = {}

    def fake_extract(video_path, frame_select):
        seen.update(path=video_path, select=frame_select)
        return _png_bytes()

    monkeypatch.setattr(U, "_extract_frame_from_video", fake_extract)

    result = U._load_source(clip, frame_select="last")

    assert seen["path"] == clip
    assert seen["select"] == "last"
    assert result.startswith(b"\x89PNG")


def test_unsupported_extension_raises(tmp_path):
    """An existing file with an unrecognised suffix is rejected with ValueError."""
    bogus = tmp_path / "weird.xyz"
    bogus.write_bytes(b"\x00")

    with pytest.raises(ValueError, match="Unsupported source suffix"):
        U._load_source(bogus, frame_select="first")


def test_missing_source_raises(tmp_path):
    """A path that does not exist on disk surfaces FileNotFoundError."""
    absent = tmp_path / "nope.png"

    with pytest.raises(FileNotFoundError):
        U._load_source(absent, frame_select="first")


# ── Aspect-ratio parameterization ───────────────────────────────

def test_detect_aspect_ratio_16_9():
    """A 64x36 landscape frame is reported as 16:9."""
    landscape = _png_bytes(size=(64, 36))
    detected = U._detect_aspect_ratio(landscape)
    assert detected == "16:9"


def test_detect_aspect_ratio_9_16():
    """A 36x64 portrait frame is reported as 9:16."""
    portrait = _png_bytes(size=(36, 64))
    detected = U._detect_aspect_ratio(portrait)
    assert detected == "9:16"


def test_aspect_ratio_respected_when_explicit(tmp_image, monkeypatch):
    """An explicit aspect_ratio argument reaches the engine call unchanged."""
    seen = {}

    def fake_nbp(image_bytes, prompt, aspect_ratio):
        seen["ar"] = aspect_ratio
        return {
            "success": True,
            "image_data": b"X",
            "cost": 0.134,
            "engine_used": "nbp",
            "model": "m",
        }

    def fake_validate(a, b):
        # Validation is stubbed to pass; only AR forwarding is under test here.
        return {
            "passed": True,
            "layer": "passed",
            "histogram": {"passed": True, "correlation": 0.99},
            "rubric": {"all_yes": True, "axes": {}},
        }

    monkeypatch.setattr(U, "_validate_uprez", fake_validate)
    monkeypatch.setattr(U, "_run_nbp_uprez", fake_nbp)

    outcome = U.uprez_frame(
        tmp_image, aspect_ratio="21:9", engine="nbp", style="photorealistic"
    )

    assert outcome["success"] is True
    assert seen["ar"] == "21:9"


# ── Engine routing ──────────────────────────────────────────────

def test_auto_selects_nb2_for_stylized_at_supported_ar():
    """Stylized looks at 16:9, 9:16 and 1:1 are all routed to NB2."""
    stylized_cases = (
        ("cartoon_2d", "16:9"),
        ("anime", "9:16"),
        ("stylized", "1:1"),
    )
    for style, aspect_ratio in stylized_cases:
        assert U._auto_select_engine(style, aspect_ratio) == "nb2"


def test_auto_selects_nbp_for_photoreal():
    """Photorealistic style at 16:9 is routed to NBP."""
    chosen = U._auto_select_engine("photorealistic", "16:9")
    assert chosen == "nbp"


def test_auto_escalates_to_nbp_when_ar_unsupported():
    """Stylized looks at 21:9 and 4:3 are routed to NBP instead of NB2."""
    for style, aspect_ratio in (("cartoon_2d", "21:9"), ("anime", "4:3")):
        assert U._auto_select_engine(style, aspect_ratio) == "nbp"


def test_explicit_nb2_at_unsupported_ar_rejects(tmp_image, monkeypatch):
    """Explicit engine='nb2' at an AR NB2 cannot serve returns success=False.

    Both engine call sites are replaced with stubs that fail loudly, so a
    regression in the aspect-ratio gate cannot fall through to the live
    Gemini API (this module's tests must stay offline).

    NOTE(review): this assumes the rejection happens before the engine
    function is invoked — confirm against uprez_frame.
    """
    engine_calls = []

    def _forbidden(name):
        # Build a stub that records the call and then raises, so the final
        # assert still catches the call if the exception is swallowed.
        def boom(*args, **kwargs):
            engine_calls.append(name)
            raise AssertionError(f"{name} must not be invoked when NB2 rejects the AR")
        return boom

    monkeypatch.setattr(U, "_run_nb2_uprez", _forbidden("_run_nb2_uprez"))
    monkeypatch.setattr(U, "_run_nbp_uprez", _forbidden("_run_nbp_uprez"))

    # No aspect_ratio is passed, so the AR comes from the image itself: use a
    # 4:3 source that NB2 cannot serve.
    p = tmp_image.parent / "ar43.png"
    p.write_bytes(_png_bytes(size=(40, 30)))  # 4:3
    r = U.uprez_frame(p, engine="nb2", style="cartoon_2d")
    assert r["success"] is False
    assert "NB2 does not support" in r["error"]
    assert engine_calls == []


def test_auto_with_unsupported_ar_silently_escalates(tmp_image, monkeypatch):
    """engine='auto' on a 21:9 source ends up on NBP and still succeeds."""
    invoked = {"which": None}

    def fake_nbp(image_bytes, prompt, aspect_ratio):
        invoked["which"] = "nbp"
        return {
            "success": True,
            "image_data": b"X",
            "cost": 0.134,
            "engine_used": "nbp",
            "model": "m",
        }

    def fake_nb2(image_bytes, prompt, aspect_ratio):
        invoked["which"] = "nb2"
        return {
            "success": True,
            "image_data": b"X",
            "cost": 0.039,
            "engine_used": "nb2",
            "model": "m",
        }

    def fake_validate(a, b):
        return {
            "passed": True,
            "layer": "passed",
            "histogram": {"passed": True, "correlation": 0.99},
            "rubric": {"all_yes": True, "axes": {}},
        }

    for attr, stub in (
        ("_run_nbp_uprez", fake_nbp),
        ("_run_nb2_uprez", fake_nb2),
        ("_validate_uprez", fake_validate),
    ):
        monkeypatch.setattr(U, attr, stub)

    # 210x90 reduces to 7:3 (the same ratio as 21:9), which is outside NB2's
    # {1:1, 9:16, 16:9}, so this exercises the escalation.
    wide = tmp_image.parent / "ar219.png"
    wide.write_bytes(_png_bytes(size=(210, 90)))

    outcome = U.uprez_frame(wide, engine="auto", style="cartoon_2d")

    assert outcome["success"] is True
    assert invoked["which"] == "nbp"


# ── Validation — rejects drift ──────────────────────────────────

def test_validate_rejects_histogram_drift(tmp_image, monkeypatch):
    """An output that drifts far from the source is rejected at the histogram layer."""
    def boom_rubric(*a, **kw):
        # The rubric stub doubles as a guard: reaching it means the
        # histogram layer did not stop the run first.
        raise AssertionError("rubric must not run when histogram fails")

    def fake_uprez(image_bytes, prompt, aspect_ratio):
        # Solid black output against the gray fixture image.
        return {
            "success": True,
            "image_data": _png_bytes(color=(0, 0, 0)),
            "cost": 0.039,
            "engine_used": "nb2",
            "model": "m",
        }

    monkeypatch.setattr(U, "_gemini_rubric_check", boom_rubric)
    monkeypatch.setattr(U, "_run_nb2_uprez", fake_uprez)

    outcome = U.uprez_frame(tmp_image, engine="nb2", style="cartoon_2d")

    assert outcome["success"] is False
    verdict = outcome["validation"]
    assert verdict["layer"] == "histogram"
    assert verdict["passed"] is False


def test_validate_rejects_rubric_no(tmp_image, monkeypatch):
    """With the histogram layer passing, a rubric NO rejects at the rubric layer."""
    def fake_uprez(image_bytes, prompt, aspect_ratio):
        return {
            "success": True,
            "image_data": _perturbed_png(_png_bytes()),
            "cost": 0.039,
            "engine_used": "nb2",
            "model": "m",
        }

    def fake_rubric(a, b):
        # Every axis passes except CHARACTER_IDENTITY.
        verdicts = {
            axis: {"passed": True, "reason": None}
            for axis in (
                "CHARACTER_IDENTITY",
                "WARDROBE",
                "SCENE_GEOMETRY",
                "COMPOSITION",
                "PALETTE",
                "LINE_WEIGHT",
            )
        }
        verdicts["CHARACTER_IDENTITY"] = {"passed": False, "reason": "face reshaped"}
        return {"all_yes": False, "axes": verdicts, "cost": 0.003}

    monkeypatch.setattr(U, "_run_nb2_uprez", fake_uprez)
    # Pin the histogram layer to PASS: the flat-color synthetic PNGs give
    # degenerate histograms, and that noise must not mask the rubric-layer
    # behavior this test is about.
    monkeypatch.setattr(
        U, "_histogram_check", lambda a, b: {"passed": True, "correlation": 0.99}
    )
    monkeypatch.setattr(U, "_gemini_rubric_check", fake_rubric)

    outcome = U.uprez_frame(tmp_image, engine="nb2", style="cartoon_2d")

    assert outcome["success"] is False
    assert outcome["validation"]["layer"] == "gemini_rubric"
    rubric = outcome["validation"]["rubric"]
    assert rubric["all_yes"] is False
    assert rubric["axes"]["CHARACTER_IDENTITY"]["passed"] is False


def test_validate_passes_happy_path(tmp_image, monkeypatch):
    """Histogram pass plus an all-YES rubric yields success and returns the engine bytes."""
    source_png = _png_bytes()
    engine_png = _perturbed_png(source_png, delta=2)

    def fake_uprez(image_bytes, prompt, aspect_ratio):
        return {
            "success": True,
            "image_data": engine_png,
            "cost": 0.039,
            "engine_used": "nb2",
            "model": "m",
        }

    def fake_rubric(a, b):
        all_pass = {ax: {"passed": True, "reason": None} for ax in U._RUBRIC_AXES}
        return {"all_yes": True, "axes": all_pass, "cost": 0.003}

    monkeypatch.setattr(U, "_run_nb2_uprez", fake_uprez)
    # Pin the histogram layer to PASS so the rubric drives the outcome; the
    # flat-color synthetic PNGs give degenerate histograms that are unrelated
    # to the validation-pass behavior under test.
    monkeypatch.setattr(
        U, "_histogram_check", lambda a, b: {"passed": True, "correlation": 0.99}
    )
    monkeypatch.setattr(U, "_gemini_rubric_check", fake_rubric)

    outcome = U.uprez_frame(tmp_image, engine="nb2", style="cartoon_2d")

    assert outcome["success"] is True
    assert outcome["validation"]["passed"] is True
    assert outcome["engine_used"] == "nb2"
    assert outcome["image_data"] == engine_png


# ── Rubric parser ───────────────────────────────────────────────

def test_parse_rubric_all_yes():
    """Six plain 'AXIS: YES' lines parse to all_yes=True with six axes."""
    response_lines = [
        "CHARACTER_IDENTITY: YES\n",
        "WARDROBE: YES\n",
        "SCENE_GEOMETRY: YES\n",
        "COMPOSITION: YES\n",
        "PALETTE: YES\n",
        "LINE_WEIGHT: YES\n",
    ]

    parsed = U._parse_rubric_response("".join(response_lines))

    assert parsed["all_yes"] is True
    assert len(parsed["axes"]) == 6


def test_parse_rubric_one_no_with_reason():
    """A numbered response with one NO fails that axis and keeps its reason text."""
    response_lines = [
        "1. CHARACTER_IDENTITY: NO — face reshaped\n",
        "2. WARDROBE: YES\n",
        "3. SCENE_GEOMETRY: YES\n",
        "4. COMPOSITION: YES\n",
        "5. PALETTE: YES\n",
        "6. LINE_WEIGHT: YES\n",
    ]

    parsed = U._parse_rubric_response("".join(response_lines))

    assert parsed["all_yes"] is False
    identity = parsed["axes"]["CHARACTER_IDENTITY"]
    assert identity["passed"] is False
    # The reason may be None; fall back to "" so the membership test cannot raise.
    assert "face reshaped" in (identity["reason"] or "")


# ── Prompt selection ────────────────────────────────────────────

def test_nbp_prompt_is_used_for_nbp_engine(tmp_image, monkeypatch):
    """The NBP engine receives the photorealistic prompt text."""
    seen = {}

    def fake_nbp(image_bytes, prompt, aspect_ratio):
        seen["prompt"] = prompt
        return {
            "success": True,
            "image_data": b"X",
            "cost": 0.134,
            "engine_used": "nbp",
            "model": "m",
        }

    def fake_validate(a, b):
        return {
            "passed": True,
            "layer": "passed",
            "histogram": {"passed": True, "correlation": 0.99},
            "rubric": {"all_yes": True, "axes": {}},
        }

    monkeypatch.setattr(U, "_validate_uprez", fake_validate)
    monkeypatch.setattr(U, "_run_nbp_uprez", fake_nbp)

    U.uprez_frame(tmp_image, engine="nbp", style="photorealistic")

    expected_fragment = (
        "A highly detailed, perfectly sharp, clean photorealistic cinematic frame"
    )
    assert expected_fragment in seen["prompt"]


def test_nb2_prompt_is_used_for_nb2_engine(tmp_image, monkeypatch):
    """The NB2 engine receives the prompt that opens with 'Sharpen and clean'."""
    seen = {}

    def fake_nb2(image_bytes, prompt, aspect_ratio):
        seen["prompt"] = prompt
        return {
            "success": True,
            "image_data": b"X",
            "cost": 0.039,
            "engine_used": "nb2",
            "model": "m",
        }

    def fake_validate(a, b):
        return {
            "passed": True,
            "layer": "passed",
            "histogram": {"passed": True, "correlation": 0.99},
            "rubric": {"all_yes": True, "axes": {}},
        }

    monkeypatch.setattr(U, "_validate_uprez", fake_validate)
    monkeypatch.setattr(U, "_run_nb2_uprez", fake_nb2)

    U.uprez_frame(tmp_image, engine="nb2", style="cartoon_2d")

    expected_opening = "Sharpen and clean this 2D cartoon animation frame"
    assert seen["prompt"].startswith(expected_opening)


# ── Invalid args ────────────────────────────────────────────────

def test_invalid_engine_rejected(tmp_image):
    """An unknown engine name yields a failure dict rather than an exception."""
    outcome = U.uprez_frame(tmp_image, engine="seedream")

    assert outcome["success"] is False
    assert "engine must be one of" in outcome["error"]


def test_invalid_style_rejected(tmp_image):
    """An unknown style name yields a failure dict rather than an exception."""
    outcome = U.uprez_frame(tmp_image, style="cyberpunk")

    assert outcome["success"] is False
    assert "style must be one of" in outcome["error"]


def test_invalid_frame_select_for_video(tmp_path):
    """An unrecognised frame_select value raises ValueError from the video extractor."""
    clip = tmp_path / "clip.mp4"
    clip.write_bytes(b"\x00")

    with pytest.raises(ValueError, match="frame_select"):
        U._extract_frame_from_video(clip, "middle")
