#!/usr/bin/env python3
"""
Test NBP direct rendering — V4: Triptych + scene reference chaining.

Changes from V3:
  - Shot 1 (ENV) generates first, its hero panel becomes the SCENE REFERENCE
    for shots 2 and 3 — locking the corridor across all shots
  - Dropped storyboard blocking/facing injection (caused direct-to-camera)
  - Explicit "never direct-to-camera" guard
  - Scene reference image labeled clearly in prompt so NBP knows what it is

Usage:
    python3 tools/test_nbp_direct.py leviathan/ --episode 1 --shots 1-3
"""

import argparse
import json
import os
import re
import sys
import time
from io import BytesIO
from pathlib import Path

# Model name passed to generate_content for every shot in this script.
MODEL_PRO = "gemini-3-pro-image-preview"

# Jinx reference images (curated picks + keystone). Paths are relative to the
# project root; load_reference_images() warns about and skips missing files.
JINX_REFS = [
    "visual/lora_candidates/JINX/keystones/Jinx_Hero.jpeg",
    "visual/lora_candidates/JINX/picks/closeup_front_neutral_p3.png",
    "visual/lora_candidates/JINX/picks/front_focused_p3.png",
]


# ── ENV Sanitization ──────────────────────────────────────────────────

# Regexes applied in order by sanitize_env_prompt(); every match is deleted.
# They are heuristics for stripping human presence from ENV-shot text, not a
# grammar-aware parser.
_HUMAN_PATTERNS = [
    # "a figure's <word> visible ..." up to the next comma/period.
    re.compile(r",?\s*a\s+figure'?s?\s+\w+\s+visible\s+[^,\.]*[,\.]?", re.I),
    # "a/the figure|person|silhouette|someone|somebody ..." up to the next comma/period.
    re.compile(r",?\s*\b(?:a|the)\s+(?:figure|person|silhouette|someone|somebody)\b[^,\.]*[,\.]?", re.I),
    # Possessive body-part phrases such as "her hands" or "his boots".
    re.compile(r"\b(?:her|his|their)\s+(?:arms?|hands?|fingers?|face|body|boots?|feet|foot|legs?|torso|shoulders?)\b", re.I),
    # Whole sentences starting with "She"/"He". The lookbehind leaves the
    # previous sentence's period in place so neighbouring sentences are not
    # fused together ("Pipes drip. She runs. Water pools." -> "Pipes drip. Water pools.").
    re.compile(r"(?:(?<=\.)|^)\s*\b(?:She|He)\s+[^\.]+\.", re.I),
    # Counted body parts: "both hands", "two feet", ...
    re.compile(r"\b(?:both|two)\s+(?:arms?|hands?|fingers?|legs?|feet)\b", re.I),
    re.compile(r"\bthe\s+figure\b", re.I),
    re.compile(r"\bfigure'?s?\b", re.I),
]

# (pattern, replacement) pairs applied after _HUMAN_PATTERNS to repair the
# punctuation and spacing left behind by the deletions.
_CLEANUP = [
    (re.compile(r",\s*,"), ","),                      # doubled commas
    (re.compile(r"\s{2,}"), " "),                     # runs of whitespace
    (re.compile(r"\s+([,\.])(?=\s|$)"), r"\1"),       # stray space before , or .
    (re.compile(r",\s*\."), "."),                     # comma directly before a period
    (re.compile(r"^\s*,\s*"), ""),                    # leading comma
]


def sanitize_env_prompt(text: str) -> str:
    """Remove human-presence phrases from an environment-only description.

    Every regex in ``_HUMAN_PATTERNS`` is deleted from *text* (in list order),
    then each ``(pattern, replacement)`` pair in ``_CLEANUP`` tidies the
    punctuation/whitespace left behind, and the result is stripped.
    """
    # Deletions first, then cleanup: one ordered list of substitutions.
    steps = [(regex, "") for regex in _HUMAN_PATTERNS]
    steps.extend(_CLEANUP)

    cleaned = text
    for regex, substitute in steps:
        cleaned = regex.sub(substitute, cleaned)
    return cleaned.strip()


# ── Data Loading ──────────────────────────────────────────────────────

def load_storyboard(project_path: Path, episode: int) -> dict:
    """Load the storyboard JSON for one episode.

    Reads ``<project>/storyboards/storyboard_ep_<NNN>.json`` where ``NNN`` is
    the episode number zero-padded to three digits.

    Args:
        project_path: Root directory of the project.
        episode: Episode number used to build the file name.

    Returns:
        The parsed JSON document.

    Exits the process with status 1 (after printing an error) if the file is
    missing, is not valid UTF-8, or is not valid JSON.
    """
    sb_path = project_path / "storyboards" / f"storyboard_ep_{episode:03d}.json"
    if not sb_path.exists():
        print(f"ERROR: Storyboard not found: {sb_path}")
        sys.exit(1)
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(sb_path, encoding="utf-8") as f:
        try:
            return json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print(f"ERROR: Invalid JSON in {sb_path}: {e}")
            sys.exit(1)


def load_breakdown(project_path: Path) -> dict:
    """Load the project's visual breakdown JSON.

    Reads ``<project>/visual/breakdown.json``.

    Args:
        project_path: Root directory of the project.

    Returns:
        The parsed JSON document.

    Exits the process with status 1 (after printing an error) if the file is
    missing, is not valid UTF-8, or is not valid JSON.
    """
    bd_path = project_path / "visual" / "breakdown.json"
    if not bd_path.exists():
        print(f"ERROR: Breakdown not found: {bd_path}")
        sys.exit(1)
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(bd_path, encoding="utf-8") as f:
        try:
            return json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print(f"ERROR: Invalid JSON in {bd_path}: {e}")
            sys.exit(1)


def load_reference_images(project_path: Path) -> list:
    """Load the Jinx reference images listed in ``JINX_REFS`` as genai Parts.

    Each entry of ``JINX_REFS`` is resolved relative to *project_path*.
    Missing files are reported with a warning and skipped.

    Args:
        project_path: Root directory of the project.

    Returns:
        ``types.Part`` objects wrapping the inline image bytes, in
        ``JINX_REFS`` order (minus missing files). May be empty.
    """
    import mimetypes

    from google.genai import types

    refs = []
    for rel in JINX_REFS:
        p = project_path / rel
        if not p.exists():
            print(f"  WARNING: Reference image not found: {p}")
            continue
        # Use the registered MIME type for the suffix (covers .webp, .gif, ...);
        # fall back to the jpeg/png heuristic when the suffix is unknown.
        guessed, _ = mimetypes.guess_type(p.name)
        if guessed and guessed.startswith("image/"):
            mime = guessed
        else:
            mime = "image/jpeg" if p.suffix.lower() in (".jpg", ".jpeg") else "image/png"
        data = p.read_bytes()
        refs.append(types.Part(inline_data=types.Blob(
            mime_type=mime, data=data
        )))
        # Report the size of the bytes actually loaded.
        print(f"  Loaded ref: {p.name} ({len(data) / 1024:.0f} KB)")
    return refs


# ── Triptych Prompt ───────────────────────────────────────────────────

def build_triptych_prompt(shot: dict, storyboard: dict, is_env: bool,
                          jinx_desc: str = "",
                          has_scene_ref: bool = False) -> str | None:
    """Build a single prompt that generates a 3-panel triptych strip.

    Each panel is described explicitly. Shared elements (environment,
    lighting, camera, character) are stated once at the top so the model
    treats all three panels as the same scene at different moments.

    Returns the prompt text only — reference image labels are handled
    in the parts assembly (build_parts_for_shot).
    """
    first_text = shot.get("first_frame", "")
    hero_text = shot.get("hero_frame", "")
    last_text = shot.get("last_frame", "")

    if not hero_text:
        return None

    # Sanitize ENV shots
    if is_env:
        first_text = sanitize_env_prompt(first_text) if first_text else ""
        hero_text = sanitize_env_prompt(hero_text)
        last_text = sanitize_env_prompt(last_text) if last_text else ""

    # Determine panel count (2 if last_frame is empty)
    has_last = bool(last_text.strip())
    panel_count = 3 if has_last else 2

    # Shared scene context
    cinematic = storyboard.get("cinematic", "")
    location = storyboard.get("location", "")
    lighting = shot.get("lighting", "")
    atmosphere = shot.get("atmosphere", "")
    action = shot.get("action", "")
    emotion = shot.get("emotion", "")
    camera = (f"{shot['shot_type']} shot, {shot['camera_angle']} angle, "
              f"{shot.get('focal_length', '50mm')} at {shot.get('aperture', 'f/2.0')}")

    # Action/emotion emphasis — these drive the image, not just describe it
    action_block = ""
    if action or emotion:
        action_block = "\nACTION & EMOTION (this is what's HAPPENING):\n"
        if action:
            action_block += f"Action: {action}\n"
        if emotion:
            action_block += f"Emotion: {emotion}\n"
        action_block += (
            "The character must be DOING this — body in motion, muscles "
            "engaged, expression showing this emotion. Not posed. Not static. "
            "Caught mid-action.\n"
        )

    # ENV negative
    env_block = ""
    if is_env:
        env_block = (
            "\nCRITICAL: This is an ENVIRONMENT-ONLY shot. "
            "ABSOLUTELY NO PEOPLE in ANY panel. No human figures, no silhouettes, "
            "no hands, no arms, no body parts, no faces. "
            "Only physical environment, objects, and atmosphere.\n"
        )

    # Camera direction guard
    camera_guard = ""
    if not is_env:
        camera_guard = (
            "\nCAMERA DIRECTION: The character should NOT look directly at "
            "the camera. No direct-to-camera eye contact. Use 3/4 angles, "
            "profile views, or over-shoulder compositions. The camera is an "
            "invisible observer.\n"
        )

    # Build panel descriptions
    if panel_count == 3:
        panel_section = (
            f"PANEL 1 (LEFT) — the moment before:\n{first_text}\n\n"
            f"PANEL 2 (CENTER) — the peak action:\n{hero_text}\n\n"
            f"PANEL 3 (RIGHT) — the moment after:\n{last_text}\n"
        )
        layout_instruction = (
            "Generate a SINGLE wide image containing exactly THREE panels "
            "arranged side by side, left to right. "
            "Each panel is a separate cinematic still from the SAME scene "
            "at three sequential moments. "
            "Panels must share the SAME environment, SAME lighting, SAME "
            "color palette, and SAME camera angle. The corridor walls, rust "
            "texture, pipe placement, and lighting direction must be identical "
            "across all panels — only the action changes. "
            "Separate panels with a thin black vertical divider line."
        )
    else:
        panel_section = (
            f"PANEL 1 (LEFT) — the moment before:\n{first_text}\n\n"
            f"PANEL 2 (RIGHT) — the peak action:\n{hero_text}\n"
        )
        layout_instruction = (
            "Generate a SINGLE wide image containing exactly TWO panels "
            "arranged side by side, left to right. "
            "Each panel is a separate cinematic still from the SAME scene "
            "at two sequential moments. "
            "Panels must share the SAME environment, SAME lighting, SAME "
            "color palette, and SAME camera angle. The corridor walls, rust "
            "texture, pipe placement, and lighting direction must be identical "
            "across both panels — only the action changes. "
            "Separate panels with a thin black vertical divider line."
        )

    prompt = (
        f"{layout_instruction}\n\n"
        f"SHARED SCENE CONTEXT (applies to ALL panels):\n"
        f"Location: {location}\n"
        f"Camera: {camera}\n"
        f"Lighting: {lighting}\n"
        f"Atmosphere: {atmosphere}\n"
        f"Technical: {cinematic}\n"
        f"{action_block}"
        f"{camera_guard}"
        f"{env_block}\n"
        f"PANEL DESCRIPTIONS:\n\n"
        f"{panel_section}\n"
        f"Photorealistic. No text overlays. No labels. No panel numbers. "
        f"No comic style. No illustration. Pure cinematic stills in a strip."
    )

    return prompt


def build_parts_for_shot(prompt_text: str, is_env: bool,
                         scene_ref_part, char_ref_parts: list,
                         jinx_desc: str, types_module) -> list:
    """Order the request parts, placing a text label before each reference group.

    For character shots (``is_env`` False) the list may contain:
      1. a SCENE/ENVIRONMENT label followed by ``scene_ref_part`` (if given),
      2. a CHARACTER REFERENCE label (embedding ``jinx_desc``) followed by
         every part in ``char_ref_parts`` (if non-empty).
    ENV shots get no references. The main prompt part always comes last.
    """
    make_part = types_module.Part
    assembled = []

    if not is_env:
        if scene_ref_part is not None:
            scene_label = (
                "SCENE/ENVIRONMENT REFERENCE — this image shows the corridor "
                "environment. Match the walls, rust, pipes, lighting color, "
                "and atmosphere in your output:"
            )
            assembled += [make_part(text=scene_label), scene_ref_part]

        if char_ref_parts:
            character_label = (
                f"CHARACTER REFERENCE IMAGES — these show JINX. "
                f"Ignore the backgrounds in these reference images. "
                f"Only use them for her face, hair, body type, and identity. "
                f"Visual description: {jinx_desc}"
            )
            assembled.append(make_part(text=character_label))
            assembled += char_ref_parts

    assembled.append(make_part(text=prompt_text))
    return assembled


# ── Generation & Splitting ────────────────────────────────────────────

def generate_triptych(client, model: str, parts: list,
                      output_path: Path) -> dict:
    """Call NBP once for a panel strip and save the first image it returns.

    Args:
        client: A ``google.genai`` client (``client.models.generate_content``).
        model: Model name passed to ``generate_content``.
        parts: Ordered request parts (reference labels, images, main prompt).
        output_path: Where the raw image bytes are written; its parent
            directories are created on success.

    Returns:
        On success ``{"success": True, "elapsed", "output", "image_data"}``.
        On failure ``{"success": False, "elapsed", "error"}``, where ``error``
        (at most 300 characters) is the exception text, any text the model
        returned, or "No image in response" plus the finish reason if known.
    """
    from google.genai import types

    t0 = time.perf_counter()
    try:
        response = client.models.generate_content(
            model=model,
            contents=[types.Content(parts=parts)],
            config=types.GenerateContentConfig(
                response_modalities=["IMAGE", "TEXT"],
                temperature=1.0,
                image_config=types.ImageConfig(
                    aspect_ratio="21:9",
                ),
            ),
        )
    except Exception as e:  # any API/network failure is reported per shot
        elapsed = time.perf_counter() - t0
        return {"success": False, "elapsed": elapsed, "error": str(e)[:300]}

    elapsed = time.perf_counter() - t0

    # A blocked/empty response may carry a candidate whose ``content`` or
    # ``content.parts`` is None; treat that as "no parts" instead of crashing.
    candidates = response.candidates or []
    candidate = candidates[0] if candidates else None
    content = getattr(candidate, "content", None)
    response_parts = getattr(content, "parts", None) or []

    for part in response_parts:
        blob = getattr(part, "inline_data", None)
        mime_type = getattr(blob, "mime_type", None) or ""
        data = getattr(blob, "data", None)
        if data and mime_type.startswith("image/"):
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_path.write_bytes(data)
            return {"success": True, "elapsed": elapsed,
                    "output": str(output_path),
                    "image_data": data}

    # No image: surface whatever text the model sent back, if any.
    text_parts = [part.text for part in response_parts
                  if getattr(part, "text", None)]
    if text_parts:
        error_msg = "; ".join(text_parts)
    else:
        error_msg = "No image in response"
        finish_reason = getattr(candidate, "finish_reason", None)
        if finish_reason is not None:
            error_msg += f" (finish_reason={finish_reason})"
    return {"success": False, "elapsed": elapsed, "error": error_msg[:300]}


def split_triptych(image_data: bytes, output_dir: Path, base_name: str,
                   panel_count: int = 3) -> list[Path]:
    """Split a horizontal panel strip into one PNG file per panel.

    The strip is cut into *panel_count* equal-width columns; leftover pixels
    from the integer division go to the last panel. Panel N (counting from 1,
    left to right) is written to ``output_dir / f"{base_name}_f{N}.png"``.

    Args:
        image_data: Encoded image bytes of the whole strip.
        output_dir: Existing directory to write the panel files into.
        base_name: File-name prefix for each panel.
        panel_count: Number of panels in the strip; must be >= 1.

    Returns:
        Paths of the written panel files, left to right.

    Raises:
        ValueError: If *panel_count* is less than 1.
    """
    if panel_count < 1:
        raise ValueError(f"panel_count must be >= 1, got {panel_count}")

    from PIL import Image

    saved = []
    with Image.open(BytesIO(image_data)) as img:
        w, h = img.width, img.height
        panel_w = w // panel_count
        for i in range(panel_count):
            suffix = f"f{i + 1}"
            left = i * panel_w
            # Last panel absorbs the remainder so no pixel column is lost.
            right = w if i == panel_count - 1 else (i + 1) * panel_w
            panel = img.crop((left, 0, right, h))
            out_path = output_dir / f"{base_name}_{suffix}.png"
            panel.save(out_path, "PNG")
            saved.append(out_path)
            print(f"    Panel {suffix}: {right - left}x{h}px → {out_path.name}")

    return saved


def extract_hero_panel(image_data: bytes, panel_count: int = 3) -> bytes:
    """Return the hero (peak-action) panel of a strip as PNG bytes.

    The strip prompt puts the peak action in PANEL 2, i.e. the second column:
    the center panel of a 3-panel strip and the right panel of a 2-panel
    strip. A 1-panel strip returns the whole image. Column widths match
    split_triptych: equal widths, with the last panel absorbing any remainder.

    Args:
        image_data: Encoded image bytes of the whole strip.
        panel_count: Number of panels in the strip; must be >= 1.

    Returns:
        PNG-encoded bytes of the hero panel.

    Raises:
        ValueError: If *panel_count* is less than 1.
    """
    if panel_count < 1:
        raise ValueError(f"panel_count must be >= 1, got {panel_count}")

    from PIL import Image

    hero_index = min(1, panel_count - 1)
    with Image.open(BytesIO(image_data)) as img:
        panel_w = img.width // panel_count
        left = hero_index * panel_w
        right = img.width if hero_index == panel_count - 1 else left + panel_w
        hero = img.crop((left, 0, right, img.height))

    buf = BytesIO()
    hero.save(buf, "PNG")
    return buf.getvalue()


# ── Utilities ─────────────────────────────────────────────────────────

def parse_shot_range(s: str) -> list[int]:
    """Parse a shot selection such as ``"1-3"``, ``"1,2,5"`` or ``"1-3,7"``.

    *s* is a comma-separated list of items; each item is a single shot number
    or an inclusive ``lo-hi`` range. Whitespace around numbers is ignored.
    Order is preserved and duplicates are kept.

    Raises:
        ValueError: If an item is not an integer (including empty items) or a
            range has ``lo > hi``.
    """
    shots: list[int] = []
    for item in s.split(","):
        lo, sep, hi = item.partition("-")
        if sep:
            start, end = int(lo), int(hi)
            if start > end:
                raise ValueError(f"Invalid shot range {item.strip()!r}: start > end")
            shots.extend(range(start, end + 1))
        else:
            shots.append(int(item))
    return shots


# ── Main ──────────────────────────────────────────────────────────────

def main():
    """CLI entry point: render one NBP panel strip per requested shot.

    For each shot: build the prompt, attach labeled references, generate a
    strip, capture the first successful ENV shot's hero panel as the scene
    reference for later shots, and split the strip into per-panel PNGs.
    Finally prints a summary and, on macOS, opens the generated strips.

    Exits with status 1 if the project path, GOOGLE_API_KEY, storyboard,
    breakdown, or every character reference image is missing.
    """
    import subprocess

    parser = argparse.ArgumentParser(
        description="Test NBP direct rendering — V4 triptych + scene refs")
    parser.add_argument("project", help="Project path (e.g. leviathan/)")
    parser.add_argument("--episode", type=int, default=1)
    parser.add_argument("--shots", default="1-3",
                        help="Shot range (e.g. 1-3 or 1,2,5)")
    args = parser.parse_args()

    project_path = Path(args.project).resolve()
    if not project_path.exists():
        print(f"ERROR: Project path not found: {project_path}")
        sys.exit(1)

    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        print("ERROR: GOOGLE_API_KEY environment variable not set")
        sys.exit(1)

    from google import genai
    from google.genai import types

    client = genai.Client(api_key=api_key)

    # Load data
    storyboard = load_storyboard(project_path, args.episode)
    breakdown = load_breakdown(project_path)
    shot_ids = parse_shot_range(args.shots)

    jinx_desc = breakdown["characters"]["JINX"]["visual_description"]
    print(f"\nJinx visual desc: {jinx_desc[:80]}...")
    print(f"Model: {MODEL_PRO}")
    print("Strategy: triptych per shot + scene reference chaining")
    print("  Shot 1 (ENV) → establishes corridor → hero panel becomes")
    print("  scene reference for shots 2, 3, etc.\n")

    # Pre-load character reference images
    print("Loading character reference images...")
    char_ref_parts = load_reference_images(project_path)
    if not char_ref_parts:
        print("ERROR: No reference images loaded")
        sys.exit(1)
    print(f"  {len(char_ref_parts)} character refs loaded\n")

    # Output directory
    output_dir = project_path / "storyboards" / "assets" / f"ep_{args.episode:03d}"
    output_dir.mkdir(parents=True, exist_ok=True)

    cost_per_image = 0.134
    shots_by_id = {s["id"]: s for s in storyboard["shots"]}
    results = []
    total_cost = 0.0

    # Scene reference — built from first ENV shot's hero panel
    scene_ref_part = None

    for shot_id in shot_ids:
        shot = shots_by_id.get(shot_id)
        if not shot:
            print(f"WARNING: Shot {shot_id} not found in storyboard")
            continue

        characters = shot.get("characters_in_shot") or []
        has_characters = len(characters) > 0
        is_env = not has_characters

        # The panel count must match what build_triptych_prompt asks for: it
        # decides 2 vs 3 panels from the *sanitized* last_frame on ENV shots.
        # Splitting and hero extraction below depend on the same count.
        last_frame = shot.get("last_frame") or ""
        if is_env:
            last_frame = sanitize_env_prompt(last_frame)
        panel_count = 3 if last_frame.strip() else 2

        cast_label = "ENV" if is_env else "CHAR: " + str(characters)
        print(f"\n{'='*60}")
        print(f"SHOT {shot_id}: {shot['name']}")
        print(f"  Type: {shot['shot_type']} | Angle: {shot['camera_angle']}"
              f" | {cast_label}")
        print(f"  Panels: {panel_count} | ENV sanitize: {'ON' if is_env else 'OFF'}"
              f" | Scene ref: {'YES' if scene_ref_part else 'NO'}")
        print(f"{'='*60}")

        # Build triptych prompt
        prompt = build_triptych_prompt(
            shot, storyboard, is_env, jinx_desc)
        if not prompt:
            print("  SKIP: no hero_frame text")
            continue

        # Assemble parts with labeled references
        parts = build_parts_for_shot(
            prompt, is_env, scene_ref_part, char_ref_parts,
            jinx_desc, types)

        ref_summary = []
        if scene_ref_part and not is_env:
            ref_summary.append("1 scene")
        if has_characters:
            ref_summary.append(f"{len(char_ref_parts)} character")
        print(f"  References: {' + '.join(ref_summary) if ref_summary else 'none'}")

        # Generate triptych strip
        strip_name = f"LEV_EP{args.episode:03d}_S{shot_id:02d}_NBP_strip.png"
        strip_path = output_dir / strip_name

        print(f"  Generating {panel_count}-panel triptych ...", end=" ", flush=True)
        result = generate_triptych(client, MODEL_PRO, parts, strip_path)

        if not result["success"]:
            print(f"FAIL ({result['elapsed']:.1f}s): {result['error']}")
            results.append({"shot_id": shot_id, "type": "triptych", **result})
            continue

        print(f"OK ({result['elapsed']:.1f}s)")
        total_cost += cost_per_image

        # First successful ENV shot: its hero panel becomes the scene
        # reference passed to every later character shot.
        if is_env and scene_ref_part is None:
            hero_bytes = extract_hero_panel(result["image_data"], panel_count)
            scene_ref_part = types.Part(inline_data=types.Blob(
                mime_type="image/png", data=hero_bytes
            ))
            print(f"  → Captured hero panel as SCENE REFERENCE "
                  f"({len(hero_bytes) / 1024:.0f} KB)")

        # Split into individual panels
        print("  Splitting strip...")
        base = f"LEV_EP{args.episode:03d}_S{shot_id:02d}_NBP"
        panels = split_triptych(result["image_data"], output_dir, base,
                                panel_count)

        results.append({
            "shot_id": shot_id,
            "type": "triptych",
            "success": True,
            "elapsed": result["elapsed"],
            "strip": strip_name,
            "panels": [p.name for p in panels],
        })

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    successes = sum(1 for r in results if r.get("success"))
    failures = sum(1 for r in results if not r.get("success"))
    total_panels = sum(len(r.get("panels", [])) for r in results)
    print(f"  Triptychs: {successes}/{len(results)}")
    print(f"  Panels:    {total_panels}")
    print(f"  Failed:    {failures}")
    print(f"  Est. cost: ${total_cost:.2f}")
    print(f"  Output:    {output_dir}/")
    print("\n  V4 improvements:")
    print("    Scene ref:      Shot 1 hero → scene ref for shots 2, 3")
    print("    Camera guard:   No direct-to-camera (invisible observer)")
    print("    Blocking:       Dropped storyboard facing data (was broken)")
    print("    Location:       Storyboard location field injected into prompt")

    if failures:
        print("\n  FAILURES:")
        for r in results:
            if not r.get("success"):
                print(f"    Shot {r['shot_id']}: {r.get('error', 'unknown')}")

    # Open strip images in Preview for review
    if successes > 0 and sys.platform == "darwin":
        strip_files = [str(output_dir / r["strip"]) for r in results
                       if r.get("success") and r.get("strip")]
        if strip_files:
            # Pass paths as argv (no shell) so quotes or `$` in paths are harmless.
            subprocess.run(["open", *strip_files], check=False)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
