#!/usr/bin/env python3
"""
generate_rough_frames.py — Generate rough visualization frames from storyboard JSON

Reads a storyboard JSON file and generates rough frame images for each shot
using Gemini API. Includes available reference images for consistency.

Usage:
    python3 generate_rough_frames.py leviathan/ --episode 1
    python3 generate_rough_frames.py leviathan/ --episode 1 --shots 1-5
    python3 generate_rough_frames.py leviathan/ --episode 1 --skip-existing
    python3 generate_rough_frames.py leviathan/ --episode 1 --dry-run
    python3 generate_rough_frames.py leviathan/ --episode 1 --model gemini-2.0-flash-exp

Env vars:
    GOOGLE_API_KEY — Gemini API key (required)

Dependencies:
    pip install google-genai Pillow
"""

import argparse
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple


def resolve_ref_image(project_dir: Path, ref_path: str) -> Optional[Path]:
    """Resolve a storyboard reference path to an existing path under ``visual/``.

    Candidates are tried in order:
    1. ``<project_dir>/visual/<ref_path>`` exactly as given.
    2. If ``ref_path`` starts with ``refs/``, the same path with that leading
       segment removed.

    Args:
        project_dir: Project root that contains the ``visual/`` directory.
        ref_path: Reference path as written in the storyboard.

    Returns:
        The first candidate that exists, or ``None`` when nothing matches or
        ``ref_path`` is empty.
    """
    # An empty path would otherwise resolve to the visual/ directory itself.
    if not ref_path:
        return None

    visual_dir = project_dir / "visual"
    candidates = [visual_dir / ref_path]

    prefix = "refs/"
    if ref_path.startswith(prefix):
        stripped = ref_path[len(prefix):]
        # Skip a bare "refs/" so the fallback never points at visual/ itself.
        if stripped:
            candidates.append(visual_dir / stripped)

    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def find_character_hero(project_dir: Path, char_name: str) -> Optional[Path]:
    """Locate a character's hero image in ``visual/refs/characters/heroes``.

    The name is tried as given, capitalized, upper-cased and title-cased. For
    each spelling the globs ``<name>_Hero.*``, ``<name>_hero.*`` and
    ``<name>.*`` are tried in that order, keeping only files with a
    .png/.jpg/.jpeg/.webp suffix. The first glob that yields any image decides
    the result: a file whose stem contains "hero" wins, otherwise the first
    image found is returned.

    Returns:
        Path to the selected image, or ``None`` if the heroes directory is
        missing or no image matches.
    """
    heroes_dir = project_dir / "visual" / "refs" / "characters" / "heroes"
    if not heroes_dir.exists():
        return None

    image_suffixes = (".png", ".jpg", ".jpeg", ".webp")
    spellings = (
        char_name,
        char_name.capitalize(),
        char_name.upper(),
        char_name.title(),
    )

    for spelling in spellings:
        globs = (f"{spelling}_Hero.*", f"{spelling}_hero.*", f"{spelling}.*")
        for glob_pattern in globs:
            images = [
                candidate
                for candidate in heroes_dir.glob(glob_pattern)
                if candidate.suffix.lower() in image_suffixes
            ]
            if not images:
                continue
            # Favour an explicitly hero-named file over a bare name match.
            for image in images:
                if "hero" in image.stem.lower():
                    return image
            return images[0]
    return None


def find_location_hero(project_dir: Path) -> Optional[Path]:
    """Return ``visual/refs/locations/lower_decks_hero.png`` if it exists, else ``None``."""
    candidate = project_dir / "visual" / "refs" / "locations" / "lower_decks_hero.png"
    return candidate if candidate.exists() else None


def build_prompt_with_refs(shot: dict, storyboard: dict, project_dir: Path) -> Tuple[str, List[Tuple[Path, str]]]:
    """Build the generation prompt and collect reference images for a shot.

    A character counts as present in the shot when either:
    * its storyboard key appears as a whole word (case-insensitive) in the
      shot's ``first_frame``, ``subject``, ``name``, ``script_excerpt`` or
      ``action`` text, or
    * a ``generation_metadata.reference_slots`` path containing "characters"
      has a path segment equal (case-insensitive) to the key.

    Each detected character contributes its hero image when one is found. The
    location hero is added only for WIDE/LS/EWS/MS/POV shots.

    Args:
        shot: One shot entry from the storyboard.
        storyboard: The full storyboard; ``characters`` and ``cinematic`` are read.
        project_dir: Project root used to locate reference images.

    Returns:
        (prompt_text, [(image_path, role_description), ...])

    Raises:
        KeyError: If the shot has no ``first_frame`` entry.
    """
    # JSON nulls arrive as None; treat them as empty text so join() cannot fail.
    search_text = " ".join(
        str(shot.get(field) or "")
        for field in ("first_frame", "subject", "name", "script_excerpt", "action")
    ).lower()

    # Empty keys are dropped: an empty name would match every shot.
    char_names = [name for name in (storyboard.get("characters") or {}) if name]
    canonical_by_lower = {name.lower(): name for name in char_names}

    characters_in_shot = []
    seen = set()
    for char_name in char_names:
        # Whole-word match so "kai" is not detected inside "kaiju". Lookarounds
        # (rather than \b) also behave for names that end in punctuation.
        pattern = rf"(?<!\w){re.escape(char_name.lower())}(?!\w)"
        if re.search(pattern, search_text) and char_name.lower() not in seen:
            seen.add(char_name.lower())
            characters_in_shot.append(char_name)

    # Also pick up characters named by a path segment in the reference slots.
    ref_slots = (shot.get("generation_metadata") or {}).get("reference_slots") or {}
    if isinstance(ref_slots, dict):
        for ref_path in ref_slots.values():
            if not isinstance(ref_path, str) or "characters" not in ref_path.lower():
                continue
            for segment in ref_path.split("/"):
                key = segment.lower()
                if key in canonical_by_lower and key not in seen:
                    seen.add(key)
                    # Keep the storyboard's own spelling so hero lookup can
                    # find mixed-case file names.
                    characters_in_shot.append(canonical_by_lower[key])

    ref_images = []
    for char_name in characters_in_shot:
        hero = find_character_hero(project_dir, char_name)
        if hero:
            ref_images.append((hero, f"Character reference for {char_name}"))

    # Location hero only for shot types listed here; check the type first so
    # other shots skip the filesystem lookup.
    if shot.get("shot_type") in ("WIDE", "LS", "EWS", "MS", "POV"):
        loc_hero = find_location_hero(project_dir)
        if loc_hero:
            ref_images.append((loc_hero, "Location/environment reference"))

    # Cinematic context from the storyboard, then shot-specific settings.
    # "or <default>" makes a JSON null fall back like a missing key.
    prompt_parts = [
        f"Generate a rough storyboard frame for a cinematic microdrama. "
        f"Style: {storyboard.get('cinematic') or 'Photorealistic, cinematic lighting'}. "
        f"9:16 vertical aspect ratio.",
        f"\nShot type: {shot.get('shot_type') or 'MS'}",
        f"Camera angle: {shot.get('camera_angle') or 'eye level'}",
        f"Lens: {shot.get('focal_length') or '50mm'} {shot.get('aperture') or 'f/2.0'}",
        f"Emotion: {shot.get('emotion') or ''}",
        f"Lighting: {shot.get('lighting') or ''}",
    ]

    palette = shot.get("color_palette") or []
    if palette:
        prompt_parts.append(f"Color palette: {', '.join(str(color) for color in palette)}")

    # The main frame description always comes last.
    prompt_parts.append(f"\nFrame description:\n{shot['first_frame']}")

    return "\n".join(prompt_parts), ref_images


def generate_frame(
    client, types_module, model: str,
    prompt: str, ref_images: List[Tuple[Path, str]],
    output_path: Path, aspect_ratio: str = "9:16"
) -> Tuple[bool, str]:
    """Generate a single frame using Gemini API and write it to ``output_path``.

    Each reference image is sent as an inline blob followed by a text part
    holding its role in square brackets; the prompt is sent last. References
    that cannot be loaded are reported with a warning and skipped.

    Args:
        client: Object exposing ``models.generate_content(model=, contents=, config=)``.
        types_module: Module providing ``Part``, ``Blob``,
            ``GenerateContentConfig`` and ``ImageConfig``.
        model: Model name passed to the API.
        prompt: Text prompt for the frame.
        ref_images: ``(image_path, role_description)`` pairs.
        output_path: Destination file; parent directories are created on success.
        aspect_ratio: "9:16", "16:9" or "1:1"; any other value falls back to "9:16".

    Returns:
        (success, error_or_empty)
    """
    # .webp references are valid inputs, so they need their own MIME type;
    # unknown suffixes keep the previous image/jpeg fallback.
    mime_by_suffix = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
    }

    parts = []
    for img_path, role in ref_images:
        try:
            img_bytes = img_path.read_bytes()
            mime = mime_by_suffix.get(img_path.suffix.lower(), "image/jpeg")
            image_part = types_module.Part(
                inline_data=types_module.Blob(mime_type=mime, data=img_bytes)
            )
            label_part = types_module.Part(text=f"[{role}]")
        except Exception as e:
            # Best effort: a broken reference must not abort the whole frame.
            print(f"  WARNING: Could not load ref {img_path}: {e}")
            continue
        # Append both together so an image is never sent without its label.
        parts.extend([image_part, label_part])

    parts.append(types_module.Part(text=prompt))

    gemini_aspect = {"9:16": "9:16", "16:9": "16:9", "1:1": "1:1"}.get(aspect_ratio, "9:16")

    try:
        response = client.models.generate_content(
            model=model,
            contents=parts,
            config=types_module.GenerateContentConfig(
                response_modalities=["IMAGE", "TEXT"],
                image_config=types_module.ImageConfig(
                    aspect_ratio=gemini_aspect,
                ),
            ),
        )

        # A candidate may come back with no content or no parts; treat that as
        # "no image" instead of tripping over None.
        candidates = response.candidates or []
        first_candidate = candidates[0] if candidates else None
        content = getattr(first_candidate, "content", None)
        response_parts = getattr(content, "parts", None) or []

        for part in response_parts:
            inline = getattr(part, "inline_data", None)
            if inline and inline.data and (inline.mime_type or "").startswith("image/"):
                output_path.parent.mkdir(parents=True, exist_ok=True)
                output_path.write_bytes(inline.data)
                return True, ""

        # Text-only response: pass the model's explanation on to the caller.
        text_parts = [part.text for part in response_parts if getattr(part, "text", None)]
        if text_parts:
            return False, f"No image generated. Response: {' '.join(text_parts)[:200]}"

        finish_reason = getattr(first_candidate, "finish_reason", None)
        if finish_reason:
            return False, f"No image in response (finish_reason: {finish_reason})"
        return False, "No image in response"

    except Exception as e:
        # API and file-write failures are reported to the caller, not raised.
        return False, str(e)


def _parse_shot_ranges(spec: str) -> List[Tuple[int, int]]:
    """Parse a --shots value ('1-5', '3,7,12', '1-3,7') into inclusive (start, end) ranges."""
    ranges = []
    for token in spec.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            start_text, end_text = token.split("-", 1)
            ranges.append((int(start_text), int(end_text)))
        else:
            shot_id = int(token)
            ranges.append((shot_id, shot_id))
    if not ranges:
        raise ValueError(f"empty shot selection: {spec!r}")
    return ranges


def main():
    """CLI entry point: generate rough frames for the shots of one episode.

    Steps: parse arguments, locate the project and its storyboard JSON, filter
    shots by ``--shots``, then either print a dry-run listing or call the
    Gemini API for each requested frame and write a ``manifest.json`` next to
    the images.

    ``--frame-type`` selects which frames are produced ("first", "last" or
    "both"); with ``--skip-existing`` each frame is skipped independently when
    its file already exists.

    Exits with status 0 when no frame failed, 1 on setup errors or any failed
    frame, and argparse's usage status for a malformed ``--shots`` value.
    """
    parser = argparse.ArgumentParser(description="Generate rough frames from storyboard JSON")
    parser.add_argument("project_dir", help="Project directory (e.g., leviathan/)")
    parser.add_argument("--episode", "-e", type=int, required=True, help="Episode number")
    parser.add_argument("--shots", help="Shot range (e.g., '1-5' or '3,7,12')")
    parser.add_argument("--skip-existing", action="store_true", help="Skip shots that already have frames")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be generated without calling API")
    parser.add_argument("--model", default="gemini-2.5-flash-image", help="Gemini model name")
    parser.add_argument("--delay", type=float, default=5.0, help="Seconds between API calls (rate limiting)")
    parser.add_argument("--frame-type", choices=["first", "last", "both"], default="first",
                        help="Which frame(s) to generate per shot")
    args = parser.parse_args()

    # Validate --shots up front so a typo gives a usage error, not a traceback.
    shot_ranges = None
    if args.shots:
        try:
            shot_ranges = _parse_shot_ranges(args.shots)
        except ValueError:
            parser.error(f"invalid --shots value {args.shots!r} (expected e.g. '1-5' or '3,7,12')")

    # Resolve project directory
    project_dir = Path(args.project_dir).resolve()
    if not project_dir.exists():
        # Try relative to script location
        engine_dir = Path(__file__).resolve().parent.parent.parent
        project_dir = engine_dir / args.project_dir
    if not project_dir.exists():
        print(f"ERROR: Project directory not found: {args.project_dir}", file=sys.stderr)
        sys.exit(1)

    # Load storyboard
    ep_str = f"{args.episode:03d}"
    storyboard_path = project_dir / "storyboards" / f"storyboard_ep_{ep_str}.json"
    if not storyboard_path.exists():
        print(f"ERROR: Storyboard not found: {storyboard_path}", file=sys.stderr)
        sys.exit(1)

    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(storyboard_path, encoding="utf-8") as f:
            storyboard = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in {storyboard_path}: {e}", file=sys.stderr)
        sys.exit(1)

    shots = storyboard.get("shots", [])
    if not shots:
        print("ERROR: No shots in storyboard", file=sys.stderr)
        sys.exit(1)

    # Filter shots if a selection was given
    if shot_ranges is not None:
        shots = [
            s for s in shots
            if any(start <= s["id"] <= end for start, end in shot_ranges)
        ]
        if not shots:
            print(f"WARNING: --shots {args.shots} matched no shots", file=sys.stderr)

    output_dir = project_dir / "storyboards" / "rough_frames" / f"ep_{ep_str}"

    print(f"╔══════════════════════════════════════════════════════════╗")
    print(f"║  ROUGH FRAME GENERATOR — Episode {args.episode}                    ║")
    print(f"╠══════════════════════════════════════════════════════════╣")
    print(f"║  Storyboard: {storyboard_path.name:<42} ║")
    print(f"║  Shots: {len(shots):<47} ║")
    print(f"║  Output: {str(output_dir.relative_to(project_dir)):<46} ║")
    print(f"║  Model: {args.model:<47} ║")
    print(f"║  Frame type: {args.frame_type:<42} ║")
    print(f"╚══════════════════════════════════════════════════════════╝")
    print()

    if args.dry_run:
        print("DRY RUN — showing what would be generated:\n")
        for shot in shots:
            prompt, refs = build_prompt_with_refs(shot, storyboard, project_dir)
            ref_names = [f"{r[1]} ({r[0].name})" for r in refs]
            print(f"  Shot {shot['id']:2d}: {shot.get('name', '')}")
            print(f"           Type: {shot.get('shot_type', '?')} | Emotion: {shot.get('emotion', '')}")
            print(f"           Refs: {', '.join(ref_names) if ref_names else 'none'}")
            print(f"           Prompt: {shot['first_frame'][:80]}...")
            print()
        sys.exit(0)

    # Initialize Gemini
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
        sys.exit(1)

    try:
        from google import genai
        from google.genai import types
    except ImportError:
        print("ERROR: google-genai not installed. Run: pip install google-genai", file=sys.stderr)
        sys.exit(1)

    client = genai.Client(api_key=api_key)

    # Created only now so a dry run leaves the filesystem untouched.
    output_dir.mkdir(parents=True, exist_ok=True)

    want_first = args.frame_type in ("first", "both")
    want_last = args.frame_type in ("last", "both")

    results = {"success": 0, "failed": 0, "skipped": 0}
    api_calls = 0  # number of API requests issued so far, used for pacing
    total = len(shots)

    for position, shot in enumerate(shots, start=1):
        shot_id = shot["id"]
        header = f"  [{position}/{total}] Shot {shot_id:2d}: {shot.get('name', '')}"

        # Decide per frame what still has to be generated, so an existing
        # first frame does not prevent a missing last frame from being made.
        pending = []
        skip_notes = []
        if want_first:
            first_output = output_dir / f"shot_{shot_id:02d}_first.png"
            if args.skip_existing and first_output.exists():
                skip_notes.append("First frame: SKIPPED (exists)")
            else:
                pending.append(("first", first_output))
        if want_last and shot.get("last_frame"):
            last_output = output_dir / f"shot_{shot_id:02d}_last.png"
            if args.skip_existing and last_output.exists():
                skip_notes.append("Last frame: SKIPPED (exists)")
            else:
                pending.append(("last", last_output))

        results["skipped"] += len(skip_notes)

        if not pending:
            if skip_notes:
                print(f"{header} — SKIPPED (exists)")
            else:
                # Only reachable for --frame-type last on a shot without last_frame.
                print(f"{header} — SKIPPED (no last_frame)")
                results["skipped"] += 1
            continue

        print(header)
        print(f"           {shot.get('shot_type', '?')} | {shot.get('emotion', '')} | {shot.get('focal_length', '')}")
        for note in skip_notes:
            print(f"           {note}")

        prompt, ref_images = build_prompt_with_refs(shot, storyboard, project_dir)
        ref_names = [r[0].name for r in ref_images]
        if ref_names:
            print(f"           Refs: {', '.join(ref_names)}")

        aspect = shot.get("aspect", "9:16")

        for kind, frame_output in pending:
            if kind == "first":
                frame_prompt = prompt
            else:
                # Rebuild the prompt around last_frame instead of patching the
                # first-frame prompt with str.replace, which misfires when the
                # first-frame text is empty or occurs elsewhere in the prompt.
                frame_prompt, _ = build_prompt_with_refs(
                    {**shot, "first_frame": shot["last_frame"]}, storyboard, project_dir
                )

            # Rate limiting: pause between consecutive API calls only.
            if api_calls and args.delay > 0:
                print(f"           (waiting {args.delay}s...)")
                time.sleep(args.delay)
            api_calls += 1

            success, error = generate_frame(
                client, types, args.model,
                frame_prompt, ref_images,
                frame_output,
                aspect
            )

            if kind == "first":
                ok_message = f"           ✓ Saved: {frame_output.name}"
                fail_message = f"           ✗ FAILED: {error[:100]}"
            else:
                ok_message = f"           ✓ Last frame: {frame_output.name}"
                fail_message = f"           ✗ Last frame FAILED: {error[:100]}"

            if success:
                print(ok_message)
                results["success"] += 1
            else:
                print(fail_message)
                results["failed"] += 1

    # Summary
    print()
    print(f"═══════════════════════════════════════════════════════════")
    print(f"  RESULTS: {results['success']} generated, {results['failed']} failed, {results['skipped']} skipped")
    print(f"  Output: {output_dir}")
    print(f"═══════════════════════════════════════════════════════════")

    # Write manifest for the storyboard editor. It covers every shot in the
    # storyboard, not only the ones selected for this run.
    manifest = {
        "episode": args.episode,
        "storyboard": storyboard_path.name,
        "model": args.model,
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "frames": {}
    }
    for shot in storyboard.get("shots", []):
        shot_id = shot["id"]
        first_path = output_dir / f"shot_{shot_id:02d}_first.png"
        last_path = output_dir / f"shot_{shot_id:02d}_last.png"
        manifest["frames"][str(shot_id)] = {
            "name": shot.get("name"),
            "first_frame": str(first_path.relative_to(project_dir)) if first_path.exists() else None,
            "last_frame": str(last_path.relative_to(project_dir)) if last_path.exists() else None,
        }

    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    print(f"  Manifest: {manifest_path}")

    sys.exit(0 if results["failed"] == 0 else 1)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
