"""
frame_uprez.py — Single-frame quality uprez via Gemini generate_content.

Sibling to frame_editor.py, opposite intent: frame_editor changes content
(pose edit), frame_uprez preserves content (quality restoration). Takes one
image, or one frame extracted from a video, and returns a sharper, cleaner
version with the same composition, characters, wardrobe, palette, and scene
geometry.

Use case: I2V retry start frames extracted by ffmpeg from a prior video.
ffmpeg's lossy re-encode degrades the frame; sending that degraded frame
into Kling/Seedance as a new start frame inherits the artifacts. This
library restores crispness BEFORE the I2V retry dispatches, without
changing what the frame is of.

v0 scope per consultations/recoil/frame-uprezzing-for-retry-path/SYNTHESIS.md:
- Library module only (no CLI, no MCP, no retry-path auto-invocation).
- Polymorphic input (video | image), detected by extension.
- Engine auto-selection (NB2 for stylized content; NBP for photoreal content or
  for aspect ratios NB2 does not support; override via the `engine` kwarg).
- Aspect-ratio parameterization (fixes the hardcoded-9:16 gap in frame_editor.py).
- Two-layer validation (histogram 0.85 + 6-axis Gemini Flash rubric).
- Seedream path DISQUALIFIED (see SYNTHESIS §Where Both Engines Agreed #3).

Cost:
- NB2 path: ~$0.039 per uprez + ~$0.003 validation = ~$0.042 total.
- NBP path: ~$0.134 per uprez + ~$0.003 validation = ~$0.137 total.

Opt-in invocation only. Caller decides when to run.
"""

from __future__ import annotations

import io
import logging
import os
import re
import subprocess
import tempfile
from pathlib import Path
from typing import Optional

from PIL import Image

# Reuse frame_editor's histogram function verbatim for Layer 1 validation —
# same pipeline, same behaviour, just a tighter threshold.
from recoil.pipeline._lib.frame_editor import validate_companion as _histogram_correlation

logger = logging.getLogger(__name__)  # module logger

# ── Constants ───────────────────────────────────────────────────
# Approximate per-call USD costs (see module docstring "Cost"); the uprez
# cost and the Flash rubric cost are summed into uprez_frame's "cost" field.
NB2_UPREZ_COST = 0.039
NBP_UPREZ_COST = 0.134
FLASH_RUBRIC_COST = 0.003

# Layer 1 validation: minimum histogram correlation between the original and
# the uprezzed frame (passed to frame_editor.validate_companion as threshold).
HISTOGRAM_THRESHOLD = 0.85  # SYNTHESIS §Disagreement B — tightened from 0.70

# Accepted values for uprez_frame's `engine` and `style` keyword arguments.
VALID_ENGINES = frozenset({"auto", "nb2", "nbp"})
VALID_STYLES = frozenset({"auto", "cartoon_2d", "anime", "photorealistic", "stylized"})
# Aspect ratios the NB2 path accepts. Other ratios are auto-routed to NBP, or
# rejected when the caller explicitly asks for engine="nb2".
VALID_ASPECT_RATIOS_NB2 = frozenset({"1:1", "9:16", "16:9"})   # PROMPT_BIBLE.yaml line 316-319
# Matches frame_select="at:<seconds>" with integer or decimal seconds; group 1
# is the seconds value handed to ffmpeg's -ss.
_FRAME_SELECT_AT_RE = re.compile(r"^at:(\d+(?:\.\d+)?)$")

# ── Prompt templates (verbatim from SYNTHESIS §v0 Build Spec) ───
#
# NBP: full §4j prompt + Opus §3 face/wardrobe preservation tightening.
#   Style anchor and quality keywords parameterized so the same template
#   serves photoreal/anime/stylized without rewriting the whole string.
#
# NB2: shorter prompt per PROMPT_BIBLE §NB2 anti-pattern
#   ("Overloading with detail — cheaper model, less capable of rendering
#   fine details") and SYNTHESIS §Disagreement C.
#
# Both templates are filled via str.format(style_anchor=..., quality_keywords=...)
# using the STYLE_ANCHORS / QUALITY_KEYWORDS entries for the resolved style.

# NOTE(review): the fixed phrases "ultra-crisp line art" and "clean solid color
# fills" are sent for every style, including photorealistic (which is routed to
# NBP) — confirm this wording is intended for photoreal sources.
NBP_UPREZ_PROMPT = (
    "A highly detailed, perfectly sharp, clean {style_anchor} frame. "
    "Redraw this EXACT image with ultra-crisp line art, clean solid color fills, "
    "and sharp background detail. Absolutely NO changes to characters, poses, "
    "expressions, composition, camera angle, lighting, or scene elements. "
    "Same {style_anchor} aesthetic, just crisper and cleaner. "
    "Preserve every character's face shape, eye color, hair style, and wardrobe "
    "exactly as shown. Preserve all background elements — vehicles, props, "
    "signage, environment — without modification. "
    "The only change is QUALITY: sharper edges, cleaner color boundaries, "
    "reduced compression artifacts, higher fidelity rendering. "
    "{quality_keywords}"
)

NB2_UPREZ_PROMPT = (
    "Sharpen and clean this {style_anchor} frame. "
    "Keep every character, pose, expression, and background element EXACTLY as shown. "
    "Only improve quality: crisper lines, cleaner color fills, reduced compression artifacts. "
    "NO changes to faces, wardrobe, composition, camera angle, or scene elements. "
    "Same aesthetic, higher fidelity. {quality_keywords}"
)

# Style label substituted for {style_anchor}, keyed by resolved style (every
# VALID_STYLES value except "auto", which resolves to "stylized").
STYLE_ANCHORS = {
    "cartoon_2d":       "2D cartoon animation",
    "anime":            "anime illustration",
    "photorealistic":   "photorealistic cinematic",
    "stylized":         "stylized digital painting",
}

# Style-specific quality phrase substituted for {quality_keywords}; same keys
# as STYLE_ANCHORS.
QUALITY_KEYWORDS = {
    "cartoon_2d":       "Ultra-crisp, artifact-free, broadcast-quality detail.",
    "anime":            "Clean cel-shading, precise edge definition, vivid flat color.",
    "photorealistic":   "Film-grain-appropriate sharpness, natural skin texture, clean edge definition.",
    "stylized":         "Sharp brushwork definition, clean color boundaries, high-fidelity rendering.",
}

# 6-axis identity rubric prompt for Layer 2 validation (SYNTHESIS §validation).
# Axes from SYNTHESIS: character identity, wardrobe, scene geometry,
# composition, palette, line weight.
# The images are sent in order [before, after, prompt], so Figure 1 is the
# original and Figure 2 the uprez. The final line fixes the 'AXIS: YES' /
# 'AXIS: NO — reason' format that _parse_rubric_response reads.
UPREZ_VALIDATION_PROMPT = (
    "Compare Figure 1 (before uprez) and Figure 2 (after uprez). "
    "Figure 2 should be a higher-quality version of Figure 1 with NO content changes. "
    "Score each axis YES (preserved) or NO (changed):\n"
    "1. CHARACTER_IDENTITY: Same face shape, eye color, hair, proportions?\n"
    "2. WARDROBE: Same clothing, accessories, colors?\n"
    "3. SCENE_GEOMETRY: Same background structure, environment, props, spatial layout?\n"
    "4. COMPOSITION: Same framing, crop, camera angle, subject placement?\n"
    "5. PALETTE: Same color grade, lighting temperature, contrast?\n"
    "6. LINE_WEIGHT: Same edge/line character (stylization preserved, no 2D→3D or vice versa)?\n"
    "Respond with exactly 6 lines in the form 'AXIS_NAME: YES' or 'AXIS_NAME: NO — [what changed]'."
)

# Axis names the rubric parser accepts; these must match the labels in
# UPREZ_VALIDATION_PROMPT. The rubric passes only when all of them are YES.
_RUBRIC_AXES = (
    "CHARACTER_IDENTITY", "WARDROBE", "SCENE_GEOMETRY",
    "COMPOSITION", "PALETTE", "LINE_WEIGHT",
)


# ── Input extraction ────────────────────────────────────────────
# Lower-cased file suffixes that _load_source treats as still images (bytes
# returned as-is) versus videos (one frame extracted with ffmpeg).
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
_VIDEO_SUFFIXES = {".mp4", ".mov", ".mkv", ".webm", ".avi"}


def _extract_frame_from_video(video_path: Path, frame_select: str) -> bytes:
    """Use ffmpeg to extract one frame from a video. Returns PNG bytes.

    Args:
        video_path: Video file to read.
        frame_select: 'first' | 'last' | 'at:<seconds>'.

    Raises:
        ValueError: *frame_select* is malformed, or ffmpeg wrote no frame
            (e.g. an ``at:`` timestamp at or past the end of the video).
        subprocess.CalledProcessError: ffmpeg exited non-zero (its stderr is
            logged before re-raising).
        subprocess.TimeoutExpired: ffmpeg ran for more than 30 seconds.
        FileNotFoundError: the ``ffmpeg`` executable is not on PATH.
    """
    if frame_select == "first":
        input_args = ["-ss", "0"]
        output_args = ["-frames:v", "1"]
    elif frame_select == "last":
        # Decode (at most) the final second and let the image2 muxer overwrite
        # the same file for every frame ("-update 1"), so the file ends up
        # holding the true last frame. "-sseof -0.1 -frames:v 1" instead yields
        # the first frame after EOF-0.1s (not necessarily the last), can write
        # nothing when the final frame starts before that point, and yields
        # the FIRST frame when ffmpeg ignores -sseof (duration unknown). If
        # -sseof is ignored here, the whole video is decoded and the file
        # still ends on the last frame.
        input_args = ["-sseof", "-1"]
        output_args = ["-update", "1"]
    else:
        m = _FRAME_SELECT_AT_RE.match(frame_select)
        if not m:
            raise ValueError(
                f"frame_select must be 'first' | 'last' | 'at:<seconds>', got {frame_select!r}"
            )
        input_args = ["-ss", m.group(1)]
        output_args = ["-frames:v", "1"]

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp_path = Path(tmp.name)
    try:
        cmd = [
            # -nostdin: never let ffmpeg read interactive commands from our stdin.
            "ffmpeg", "-nostdin", "-y", "-v", "error",
            *input_args, "-i", str(video_path),
            *output_args, str(tmp_path),
        ]
        try:
            subprocess.run(cmd, check=True, stderr=subprocess.PIPE, timeout=30)
        except subprocess.CalledProcessError as e:
            # CalledProcessError's str() omits stderr, so surface it here.
            detail = (e.stderr or b"").decode("utf-8", errors="replace").strip()
            logger.warning("ffmpeg frame extraction failed for %s: %s", video_path, detail)
            raise
        frame = tmp_path.read_bytes()
        # ffmpeg may exit 0 without emitting a frame (e.g. seek past the end);
        # the pre-created temp file is then still empty.
        if not frame:
            raise ValueError(
                f"ffmpeg extracted no frame from {video_path} "
                f"(frame_select={frame_select!r}); the requested position may be "
                "past the end of the video"
            )
        return frame
    finally:
        try:
            tmp_path.unlink()
        except OSError:
            pass


def _load_source(source: Path, frame_select: str) -> bytes:
    """Return the source image as PNG/JPEG bytes regardless of video/image input."""
    source = Path(source)
    if not source.is_file():
        raise FileNotFoundError(f"uprez source not found: {source}")
    suffix = source.suffix.lower()
    if suffix in _IMAGE_SUFFIXES:
        return source.read_bytes()
    if suffix in _VIDEO_SUFFIXES:
        return _extract_frame_from_video(source, frame_select)
    raise ValueError(
        f"Unsupported source suffix {suffix!r}; expected one of "
        f"{sorted(_IMAGE_SUFFIXES | _VIDEO_SUFFIXES)}"
    )


def _detect_aspect_ratio(image_bytes: bytes) -> str:
    """Return a canonical AR string ('16:9' | '9:16' | '1:1' | 'W:H') from PIL size.

    Sizes within ~3.5% of 16:9, 9:16 or 1:1 snap to that label. Any other size
    is returned as its reduced ``"W:H"`` (e.g. "4:3" for 1024x768).

    Raises:
        OSError: PIL cannot identify *image_bytes* as an image
            (PIL's UnidentifiedImageError subclasses OSError).
    """
    from math import gcd, log

    # Tolerance on |log(w/h) - log(target)|: a relative tolerance that gives
    # the same verdict for a frame and its rotated counterpart. An absolute
    # tolerance on w/h (±0.02) is ~3.6% wide around 9:16 but only ~1.1% around
    # 16:9, so 480x864 snapped to "9:16" while 864x480 fell through to "9:5".
    tolerance = 0.035

    # Only the header is needed for .size; the context manager releases PIL's handle.
    with Image.open(io.BytesIO(image_bytes)) as img:
        w, h = img.size
    g = gcd(w, h) or 1
    reduced = f"{w // g}:{h // g}"
    if w <= 0 or h <= 0:
        return reduced  # degenerate size; nothing sensible to snap to
    ratio_log = log(w / h)
    for label, target in (("16:9", 16 / 9), ("9:16", 9 / 16), ("1:1", 1.0)):
        if abs(ratio_log - log(target)) < tolerance:
            return label
    return reduced


# ── Engine auto-selection (SYNTHESIS §v0 Build Spec) ────────────
def _auto_select_engine(style: str, aspect_ratio: str) -> str:
    """Pick the uprez engine when the caller passed engine="auto".

    NBP takes photoreal content and any aspect ratio NB2 cannot serve.
    Everything else (cartoon/anime/stylized at an NB2-supported ratio) goes
    to NB2.
    """
    nb2_can_serve = aspect_ratio in VALID_ASPECT_RATIOS_NB2
    needs_nbp = (not nb2_can_serve) or style == "photorealistic"
    return "nbp" if needs_nbp else "nb2"


def _resolve_style(style: str) -> str:
    """Collapse 'auto' to 'stylized' as the safe default; other values pass through."""
    if style == "auto":
        return "stylized"
    return style


# ── Engine-specific uprez calls (mirror frame_editor.py patterns, AR parameterized) ──

def _run_gemini_uprez(
    image_bytes: bytes,
    prompt: str,
    aspect_ratio: str,
    *,
    engine: str,
    model_role: str,
    cost: float,
) -> dict:
    """Shared Gemini generate_content call for both NBP and NB2 uprez paths.

    The two engines differ only in (a) `get_model(role)`, (b) per-call cost,
    (c) `engine_used` literal, and (d) error-message engine label. Everything
    else — config shape, AR parameterization, image-bytes extraction — is
    identical, so it lives here. Wrappers `_run_nbp_uprez` / `_run_nb2_uprez`
    remain as the engine-specific entry points (also so tests can monkeypatch
    them by name).
    """
    try:
        from google import genai
        from google.genai import types as genai_types
        from recoil.core.model_profiles import get_model
    except Exception as e:
        return {"success": False, "error": f"genai/model_profiles import failed: {e}"}

    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return {"success": False, "error": "GEMINI_API_KEY not set"}

    label = engine.upper()
    try:
        client = genai.Client(api_key=api_key)
        hero_image = Image.open(io.BytesIO(image_bytes))
        config = genai_types.GenerateContentConfig(
            temperature=0.3,
            responseModalities=["IMAGE", "TEXT"],
            imageConfig=genai_types.ImageConfig(aspectRatio=aspect_ratio),
        )
        model = get_model(model_role, "image")
        response = client.models.generate_content(
            model=model,
            contents=[hero_image, prompt],
            config=config,
        )
        data = _extract_image_bytes(response)
        if data is None:
            return {"success": False, "error": f"No image in {label} uprez response"}
        return {"success": True, "image_data": data, "cost": cost, "engine_used": engine, "model": model}
    except Exception as e:
        return {"success": False, "error": f"{label} uprez call failed: {e}"}


def _run_nbp_uprez(image_bytes: bytes, prompt: str, aspect_ratio: str) -> dict:
    """Uprez through NBP (the "production" image model role; gemini-3-pro-image-preview).

    Uses the same request shape as frame_editor.edit_hero_pose, but passes the
    aspect ratio through instead of hardcoding 9:16. Reports NBP_UPREZ_COST
    and returns _run_gemini_uprez's uniform result dict.
    """
    return _run_gemini_uprez(
        image_bytes,
        prompt,
        aspect_ratio,
        model_role="production",
        cost=NBP_UPREZ_COST,
        engine="nbp",
    )


def _run_nb2_uprez(image_bytes: bytes, prompt: str, aspect_ratio: str) -> dict:
    """Uprez through NB2 (the "exploration" image model role; gemini-3.1-flash-image-preview).

    NB2 only accepts the ratios in VALID_ASPECT_RATIOS_NB2 ({1:1, 9:16, 16:9}).
    Routing other ratios to NBP is the caller's job (see _auto_select_engine).
    As a guard, an unsupported ratio returns a failure dict without any API
    call. Reports NB2_UPREZ_COST on success.
    """
    if aspect_ratio in VALID_ASPECT_RATIOS_NB2:
        return _run_gemini_uprez(
            image_bytes,
            prompt,
            aspect_ratio,
            model_role="exploration",
            cost=NB2_UPREZ_COST,
            engine="nb2",
        )
    supported = sorted(VALID_ASPECT_RATIOS_NB2)
    return {
        "success": False,
        "error": f"NB2 does not support aspect_ratio={aspect_ratio!r}; "
                 f"supported: {supported}",
    }


def _extract_image_bytes(response) -> Optional[bytes]:
    """Mirror of frame_editor.py's inline_data extraction loop."""
    if response and response.candidates:
        for candidate in response.candidates:
            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if hasattr(part, "inline_data") and part.inline_data:
                        return part.inline_data.data
    return None


# ── Validation (Layer 1 histogram + Layer 2 Flash rubric) ───────

def _histogram_check(original_bytes: bytes, uprezzed_bytes: bytes) -> dict:
    """Layer 1 gate: histogram correlation against HISTOGRAM_THRESHOLD (0.85, SYNTHESIS §B).

    Delegates to frame_editor.validate_companion and keeps only its
    ``passed`` flag and ``correlation`` value.
    """
    result = _histogram_correlation(
        original_bytes,
        uprezzed_bytes,
        threshold=HISTOGRAM_THRESHOLD,
    )
    return {key: result[key] for key in ("passed", "correlation")}


def _parse_rubric_response(text: str) -> dict:
    """Parse a 6-line 'AXIS: YES|NO — reason' response into {axis: {passed, reason}}."""
    axes: dict = {}
    for line in (text or "").splitlines():
        line = line.strip()
        if not line:
            continue
        # Tolerate numbered list prefixes like "1. CHARACTER_IDENTITY: YES"
        line = re.sub(r"^\d+\.\s*", "", line)
        if ":" not in line:
            continue
        axis, verdict = line.split(":", 1)
        axis = axis.strip().upper()
        if axis not in _RUBRIC_AXES:
            continue
        verdict = verdict.strip()
        m = re.match(r"(YES|NO)\b[\s\-—:]*(.*)$", verdict, re.IGNORECASE)
        if not m:
            continue
        axes[axis] = {
            "passed": m.group(1).upper() == "YES",
            "reason": m.group(2).strip() or None,
        }
    all_yes = bool(axes) and all(v["passed"] for v in axes.values()) and len(axes) == len(_RUBRIC_AXES)
    return {"axes": axes, "all_yes": all_yes}


def _gemini_rubric_check(original_bytes: bytes, uprezzed_bytes: bytes) -> dict:
    """Layer 2: 6-axis Gemini Flash identity rubric (SYNTHESIS §validation)."""
    try:
        from google import genai
        from google.genai import types as genai_types
    except Exception as e:
        return {"all_yes": False, "error": f"genai import failed: {e}", "axes": {}}

    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return {"all_yes": False, "error": "GEMINI_API_KEY not set", "axes": {}}

    try:
        client = genai.Client(api_key=api_key)
        before_img = Image.open(io.BytesIO(original_bytes))
        after_img = Image.open(io.BytesIO(uprezzed_bytes))
        config = genai_types.GenerateContentConfig(temperature=0.1)
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=[before_img, after_img, UPREZ_VALIDATION_PROMPT],
            config=config,
        )
        text = ""
        if response and response.candidates:
            for candidate in response.candidates:
                if candidate.content and candidate.content.parts:
                    for part in candidate.content.parts:
                        if hasattr(part, "text") and part.text:
                            text += part.text
        parsed = _parse_rubric_response(text)
        parsed["raw_text"] = text
        parsed["cost"] = FLASH_RUBRIC_COST
        return parsed
    except Exception as e:
        return {"all_yes": False, "error": f"Flash rubric call failed: {e}", "axes": {}}


def _validate_uprez(original_bytes: bytes, uprezzed_bytes: bytes) -> dict:
    """Run the two-layer uprez gate (SYNTHESIS §validation + §Disagreement B adjudication).

    Order: histogram correlation (0.85) first. Only if that passes is the
    Gemini Flash six-axis rubric consulted, and it must answer YES on every
    axis. A failure at either layer rejects the uprez. Falling back to the
    original frame on reject is the caller's job.

    Returns:
      {
        "passed": bool,
        "layer": "histogram" | "gemini_rubric" | "passed",   # where it stopped
        "histogram": {"passed": bool, "correlation": float},
        "rubric": {...} | None,   # None when the histogram layer rejected
      }
    """
    histogram = _histogram_check(original_bytes, uprezzed_bytes)
    if not histogram["passed"]:
        # Local gate failed: skip the Flash rubric call entirely.
        return {"passed": False, "layer": "histogram", "histogram": histogram, "rubric": None}

    rubric = _gemini_rubric_check(original_bytes, uprezzed_bytes)
    accepted = bool(rubric.get("all_yes"))
    return {
        "passed": accepted,
        "layer": "passed" if accepted else "gemini_rubric",
        "histogram": histogram,
        "rubric": rubric,
    }


# ── Public entry point ─────────────────────────────────────────
def uprez_frame(
    source: Path,
    *,
    aspect_ratio: str | None = None,
    engine: str = "auto",
    style: str = "auto",
    frame_select: str = "first",
    validate: bool = True,
    project: str | None = None,
) -> dict:
    """Uprez a single image OR one frame of a video.

    Bad arguments, unreadable/undecodable sources, ffmpeg errors, uprez API
    errors and validation rejects are reported as failure dicts, not raised.

    Args:
        source: Path to an image (jpg/png/webp/...) OR a video (mp4/mov/...).
          Polymorphic — detected by extension.
        aspect_ratio: Canonical AR string ("16:9" | "9:16" | "1:1" | "W:H").
          If None, auto-detect from the source image size. Defaults to None.
          NB2 supports only {1:1, 9:16, 16:9}; unsupported ARs auto-route to NBP.
        engine: "auto" | "nb2" | "nbp". Default "auto" applies the rule:
          NB2 for stylized content at supported ARs, NBP otherwise. An
          explicit "nb2" with an unsupported AR is an error, not a reroute.
        style: "auto" | "cartoon_2d" | "anime" | "photorealistic" | "stylized".
          "auto" collapses to "stylized" as a safe default.
        frame_select: "first" | "last" | "at:<seconds>". Video only.
          Default "first" matches the I2V retry-path use case.
        validate: Run two-layer validation. Default True (strongly recommended).
          Caller falls back to original on validation reject.
        project: Optional project name, threaded through for future logging
          hooks. Not used in v0.

    Returns:
        On success: {
          "success": True,
          "image_data": bytes,
          "engine_used": "nb2" | "nbp",
          "model": "gemini-...",
          "cost": float,                 # uprez + validation actually incurred
          "aspect_ratio": str,           # resolved AR actually sent
          "style": str,                  # resolved style
          "validation": {...} | None,    # None when validate=False
        }
        On failure: {
          "success": False,
          "error": str,
          "engine_used": str | None,     # None when failing before dispatch
          "validation": {...} | None,    # set only on a validation reject
          "cost": float,                 # present only on a validation reject
        }
    """
    if engine not in VALID_ENGINES:
        return {"success": False, "error": f"engine must be one of {sorted(VALID_ENGINES)}, got {engine!r}",
                "engine_used": None, "validation": None}
    if style not in VALID_STYLES:
        return {"success": False, "error": f"style must be one of {sorted(VALID_STYLES)}, got {style!r}",
                "engine_used": None, "validation": None}

    # Load source → bytes (video-frame-extract if needed). OSError covers a
    # missing or unreadable file as well as a missing ffmpeg executable.
    try:
        original_bytes = _load_source(Path(source), frame_select=frame_select)
    except (OSError, ValueError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
        return {"success": False, "error": f"source load failed: {e}", "engine_used": None, "validation": None}

    # Resolve AR (explicit caller value wins; else auto-detect). Auto-detection
    # decodes the image header, so a corrupt/undecodable source surfaces here
    # (PIL's UnidentifiedImageError is an OSError).
    if aspect_ratio:
        resolved_ar = aspect_ratio
    else:
        try:
            resolved_ar = _detect_aspect_ratio(original_bytes)
        except (OSError, ValueError) as e:
            return {"success": False, "error": f"source load failed: could not decode image: {e}",
                    "engine_used": None, "validation": None}
    resolved_style = _resolve_style(style)

    # Engine resolution.
    chosen = engine
    if chosen == "auto":
        chosen = _auto_select_engine(resolved_style, resolved_ar)
    # NB2 cannot serve unsupported ARs. In auto mode escalate silently (per
    # SYNTHESIS); when the caller *explicitly* asked for NB2, fail fast instead
    # of switching engines behind their back.
    if chosen == "nb2" and resolved_ar not in VALID_ASPECT_RATIOS_NB2:
        if engine == "auto":
            chosen = "nbp"
        else:
            return {
                "success": False,
                "error": f"NB2 does not support aspect_ratio={resolved_ar!r}; "
                         f"pass engine='nbp' or an AR in {sorted(VALID_ASPECT_RATIOS_NB2)}",
                "engine_used": None, "validation": None,
            }

    # Build the prompt.
    style_anchor = STYLE_ANCHORS.get(resolved_style, STYLE_ANCHORS["stylized"])
    quality_kw = QUALITY_KEYWORDS.get(resolved_style, QUALITY_KEYWORDS["stylized"])
    template = NBP_UPREZ_PROMPT if chosen == "nbp" else NB2_UPREZ_PROMPT
    prompt = template.format(style_anchor=style_anchor, quality_keywords=quality_kw)

    # Dispatch the uprez call.
    if chosen == "nbp":
        call = _run_nbp_uprez(original_bytes, prompt, resolved_ar)
    else:
        call = _run_nb2_uprez(original_bytes, prompt, resolved_ar)

    if not call.get("success"):
        return {
            "success": False,
            "error": call.get("error", "uprez call failed"),
            "engine_used": chosen,
            "validation": None,
        }

    uprezzed_bytes = call["image_data"]
    total_cost = float(call.get("cost", 0.0))
    validation = None
    if validate:
        validation = _validate_uprez(original_bytes, uprezzed_bytes)
        # Charge the Flash rubric only when it actually ran: the rubric dict
        # carries "cost" only after a completed call (not on import/API-key/
        # call errors), and is None when the histogram gate short-circuited.
        rubric = validation.get("rubric") or {}
        total_cost += float(rubric.get("cost", 0.0))
        if not validation.get("passed"):
            logger.info(
                "uprez rejected by validation (layer=%s, engine=%s, source=%s)",
                validation.get("layer"), chosen, source,
            )
            return {
                "success": False,
                "error": f"uprez rejected by validation (layer={validation.get('layer')})",
                "engine_used": chosen,
                "validation": validation,
                "cost": round(total_cost, 4),
            }

    return {
        "success": True,
        "image_data": uprezzed_bytes,
        "engine_used": chosen,
        "model": call.get("model"),
        "cost": round(total_cost, 4),
        "aspect_ratio": resolved_ar,
        "style": resolved_style,
        "validation": validation,
    }
