#!/usr/bin/env python3
"""
test_gemini_swap.py — Test Gemini 2.5 Flash wardrobe swap with 3 prompt strategies.

Tests:
  A) Cropped reference  — crop torso/clothing from hero (no face), explicit swap prompt
  B) Text-only          — no reference image, just detailed wardrobe description
  C) Flat-lay first     — generate flat-lay of wardrobe, then use as reference

Source images are taken from the comprehensive test results (ab_test_comprehensive.py) for the selected shots (default: 3, 5, 6).

Usage:
    python3 test_gemini_swap.py leviathan/ --episode 1 --character jinx
    python3 test_gemini_swap.py leviathan/ --episode 1 --character jinx --dry-run

Env vars:
    FAL_KEY — fal.ai API key (required)
"""

import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

# Make sibling modules in this tools directory importable no matter which
# working directory the script is launched from (prepended once, never duplicated).
_tools_dir = str(Path(__file__).resolve().parent)
if all(entry != _tools_dir for entry in sys.path):
    sys.path.insert(0, _tools_dir)


# Display metadata for each swap strategy, keyed by the result-dict key that
# main() fills in. Insertion order defines column order in the HTML report.
METHODS = dict(
    a_cropped_ref=dict(
        label="Cropped Clothing Reference",
        short="Cropped Ref",
        description="Hero frame cropped to torso only (no face). Explicit swap prompt.",
        color="#00e5ff",
    ),
    b_text_only=dict(
        label="Text-Only Description",
        short="Text Only",
        description="No reference image. Wardrobe described in text from breakdown data.",
        color="#b388ff",
    ),
    c_flat_lay=dict(
        label="Flat-Lay Then Apply",
        short="Flat-Lay",
        description="Generate flat-lay product photo of wardrobe, then use as reference.",
        color="#69f0ae",
    ),
)

# Method keys in report/summary order.
METHOD_KEYS = [*METHODS]


def get_image_url(result: dict) -> Optional[str]:
    """Return the ``url`` of the first entry in a fal.ai ``images`` list.

    Returns None when the response has no ``images`` key, the list is empty
    (or otherwise falsy), or the first entry carries no ``url``.
    """
    images = result.get("images", [])
    if not images:
        return None
    first_image = images[0]
    return first_image.get("url")


def download_image(url: str, path: Path) -> bool:
    """Fetch ``url`` and write the response body to ``path``.

    Best-effort: any failure while fetching or writing (network error,
    non-2xx status, filesystem error) is printed and reported through the
    return value instead of being raised.

    Returns:
        True if the file was written, False otherwise.
    """
    import requests

    try:
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        path.write_bytes(response.content)
    except Exception as exc:
        print(f"    Download failed: {exc}")
        return False
    return True


def crop_hero_to_clothing(hero_url: str, output_path: Path,
                          top_frac: float = 0.20,
                          bottom_frac: float = 0.90) -> str:
    """Download the hero frame, crop it to a horizontal band, and upload the crop.

    The crop keeps the full image width and the band between ``top_frac``
    and ``bottom_frac`` of the height. With the defaults this removes the
    top 20% (intended to cut off the head/face) and the bottom 10%. Whether
    those fractions really exclude the face depends on the hero framing —
    NOTE(review): confirm for each project.

    Args:
        hero_url: URL of the hero frame to download.
        output_path: Local file the crop is saved to (Pillow picks the
            format from the extension).
        top_frac: Fraction of the height removed from the top.
        bottom_frac: Fraction of the height at which the crop ends.

    Returns:
        The URL returned by ``fal_client.upload_file`` for the saved crop.

    Raises:
        ValueError: If the fractions are not 0 <= top_frac < bottom_frac <= 1,
            or the resulting pixel band is empty.
        requests.HTTPError: If downloading the hero frame fails.

    Exits the process with status 1 if Pillow is not installed.
    """
    # Validate before any network work so bad arguments fail fast.
    if not 0.0 <= top_frac < bottom_frac <= 1.0:
        raise ValueError(
            f"need 0 <= top_frac < bottom_frac <= 1, got {top_frac}, {bottom_frac}"
        )

    import requests
    from io import BytesIO

    try:
        from PIL import Image
    except ImportError:
        print("    ERROR: Pillow not installed. pip install Pillow")
        sys.exit(1)

    r = requests.get(hero_url, timeout=60)
    r.raise_for_status()

    # Context manager releases the decoder/buffer once the crop is saved.
    with Image.open(BytesIO(r.content)) as img:
        w, h = img.size
        top = int(h * top_frac)
        bottom = int(h * bottom_frac)
        if bottom <= top:
            raise ValueError(f"crop band is empty for image height {h}")
        img.crop((0, top, w, bottom)).save(str(output_path))

    # Upload to fal.ai so the crop can be passed as a reference image URL.
    import fal_client
    return fal_client.upload_file(str(output_path))


def build_flat_lay_prompt(wardrobe_desc: str) -> str:
    """Return the text prompt asking Gemini for a flat-lay photo of ``wardrobe_desc``.

    Shared by both flat-lay generators so the two endpoints receive the
    exact same instructions.
    """
    return (
        "Generate a flat-lay product photo of the following clothing and accessories "
        "arranged neatly on a dark matte surface, shot from directly above:\n\n"
        f"{wardrobe_desc}\n\n"
        "Top-down view, soft studio lighting, each item clearly separated and visible. "
        "Show fabric texture and material details. No person, mannequin, or body — "
        "just the clothing items laid flat. Photorealistic product photography."
    )


def generate_flat_lay(wardrobe_desc: str) -> dict:
    """Generate a flat-lay product photo of the wardrobe via the Gemini *edit* endpoint.

    NOTE(review): this sends an empty ``image_urls`` list to the edit
    endpoint; main() uses ``generate_flat_lay_generate`` instead. Confirm the
    edit endpoint accepts zero input images before relying on this variant.

    Returns:
        The raw fal.ai response from ``fal_client.subscribe``.
    """
    import fal_client

    return fal_client.subscribe("fal-ai/gemini-25-flash-image/edit", arguments={
        "prompt": build_flat_lay_prompt(wardrobe_desc),
        "image_urls": [],
        "num_images": 1,
        "aspect_ratio": "1:1",
        "output_format": "png",
    })


def generate_flat_lay_generate(wardrobe_desc: str) -> dict:
    """Generate a flat-lay product photo via Gemini's text-to-image endpoint.

    No input images are sent; the wardrobe is described entirely in the
    prompt. Requests a single square PNG.

    Returns:
        The raw fal.ai response from ``fal_client.subscribe``.
    """
    import fal_client

    return fal_client.subscribe("fal-ai/gemini-25-flash-image", arguments={
        "prompt": build_flat_lay_prompt(wardrobe_desc),
        "num_images": 1,
        "aspect_ratio": "1:1",
        "output_format": "png",
    })


def gemini_swap_cropped(source_url: str, cropped_ref_url: str) -> dict:
    """Method A: re-dress the character after a cropped clothing reference.

    Sends two images to the Gemini edit endpoint — the scene (Image 1) and
    the head-free crop of the hero frame (Image 2) — with a prompt that
    restricts the edit to clothing and accessories.

    Returns:
        The raw fal.ai response from ``fal_client.subscribe``.
    """
    import fal_client

    instructions = "\n".join([
        "Using Image 1 (the character in a scene) and Image 2 (a close-up of the "
        "target clothing and gear):",
        "",
        "Change ONLY the character's clothing and accessories in Image 1 to match "
        "the outfit shown in Image 2.",
        "",
        "PRESERVE EXACTLY:",
        "- Character's face, facial expression, and hair",
        "- Body pose and proportions",
        "- Background environment, scene, and setting",
        "- Lighting direction, color temperature, and shadows",
        "- Camera angle and framing",
        "",
        "CHANGE ONLY:",
        "- Clothing and accessories to match Image 2",
        "- Ensure realistic fabric folds and shadows that match Image 1's lighting",
        "",
        "The result should look like the exact same photograph but with different clothes.",
    ])
    payload = dict(
        prompt=instructions,
        image_urls=[source_url, cropped_ref_url],
        num_images=1,
        output_format="png",
    )
    return fal_client.subscribe("fal-ai/gemini-25-flash-image/edit", arguments=payload)


def gemini_swap_text_only(source_url: str, wardrobe_desc: str) -> dict:
    """Method B: re-dress the character from a written outfit description only.

    Only the scene image is sent; the target wardrobe text is embedded in the
    prompt between the edit instruction and the preserve/change rules.

    Returns:
        The raw fal.ai response from ``fal_client.subscribe``.
    """
    import fal_client

    # Paragraphs are separated by a blank line in the final prompt.
    paragraphs = [
        "Edit the character's clothing in this image. Change their outfit to:",
        f"{wardrobe_desc}",
        "PRESERVE EXACTLY:\n"
        "- Character's face, facial expression, and hair\n"
        "- Body pose and proportions\n"
        "- Background environment and setting\n"
        "- Lighting and shadows\n"
        "- Camera angle and framing",
        "CHANGE ONLY the clothing and accessories as described above. "
        "Ensure realistic fabric texture, folds, and shadows that match "
        "the existing lighting in the image.",
    ]
    payload = dict(
        prompt="\n\n".join(paragraphs),
        image_urls=[source_url],
        num_images=1,
        output_format="png",
    )
    return fal_client.subscribe("fal-ai/gemini-25-flash-image/edit", arguments=payload)


def gemini_swap_flat_lay(source_url: str, flat_lay_url: str) -> dict:
    """Method C: re-dress the character using a generated flat-lay photo.

    Sends the scene (Image 1) and the flat-lay product photo (Image 2) to the
    Gemini edit endpoint, asking it to put the laid-out items on the
    character while keeping everything else unchanged.

    Returns:
        The raw fal.ai response from ``fal_client.subscribe``.
    """
    import fal_client

    instructions = "\n".join([
        "Using Image 1 (a character in a scene) and Image 2 (a flat-lay product "
        "photo of clothing items):",
        "",
        "Dress the character in Image 1 with the clothing items shown in Image 2.",
        "",
        "PRESERVE EXACTLY:",
        "- Character's face, facial expression, and hair",
        "- Body pose and proportions",
        "- Background environment and setting",
        "- Lighting and shadows",
        "- Camera angle and framing",
        "",
        "CHANGE ONLY:",
        "- Put the clothing from Image 2 onto the character",
        "- Match fabric colors, textures, and patterns from Image 2",
        "- Ensure realistic fit, folds, and shadows matching Image 1's lighting",
        "",
        "The result should look like the same photograph but wearing the clothes "
        "from the flat-lay.",
    ])
    payload = dict(
        prompt=instructions,
        image_urls=[source_url, flat_lay_url],
        num_images=1,
        output_format="png",
    )
    return fal_client.subscribe("fal-ai/gemini-25-flash-image/edit", arguments=payload)


def build_html(results: list, hero_url: str, output_dir: Path,
               episode: int, character: str) -> Path:
    """Render the side-by-side comparison page and write it into ``output_dir``.

    Layout: a "Reference Materials" section (the ``_refs`` dict of the first
    result that has one), then one row per shot showing the source image
    followed by one column per key in ``METHOD_KEYS``. Clicking any image
    opens it in a lightbox.

    All data-derived text (character/shot names, reference labels and notes,
    error messages, filenames) is HTML-escaped, so e.g. an exception message
    containing ``<`` or ``&`` cannot break the markup.

    Args:
        results: Per-shot dicts with ``shot_id``, ``name``, optional
            ``source_filename``, optional ``_refs`` (label -> {"filename",
            "note"}) and, per method key, either {"filename", "elapsed"} or
            {"error"}.
        hero_url: Currently unused; kept for interface compatibility.
        output_dir: Directory the HTML is written into. Image ``src``
            attributes are bare filenames, i.e. relative to this directory.
        episode: Episode number (rendered zero-padded to 3 digits).
        character: Character key used in the page title and output filename.

    Returns:
        Path of the written HTML file.
    """
    from html import escape

    title_char = escape(character.title())
    n_cols = 1 + len(METHOD_KEYS)  # source column + one column per method

    head = f"""<!DOCTYPE html>
<html><head><meta charset="UTF-8">
<title>Gemini Swap Test — EP{episode:03d} — {title_char}</title>
<style>
  * {{ box-sizing: border-box; margin: 0; padding: 0; }}
  body {{ background: #0a0a0f; color: #c8c8d4; font-family: 'SF Mono', 'Fira Code', monospace; }}
  .header {{ padding: 24px 32px; border-bottom: 1px solid #1a1a28; background: linear-gradient(180deg, #0e0e16, #0a0a0f); }}
  .header h1 {{ font-size: 13px; letter-spacing: 4px; text-transform: uppercase; color: #00e5ff; margin-bottom: 8px; }}
  .header .sub {{ font-size: 11px; color: #505068; line-height: 1.6; }}
  .section {{ padding: 20px 32px; border-bottom: 1px solid #1a1a28; }}
  .section h2 {{ font-size: 11px; letter-spacing: 3px; text-transform: uppercase; color: #b388ff; margin-bottom: 12px; }}
  .refs {{ display: flex; gap: 16px; align-items: flex-start; flex-wrap: wrap; }}
  .ref-card {{ background: #0e0e16; border: 1px solid #1a1a28; border-radius: 6px; padding: 12px; }}
  .ref-card h3 {{ font-size: 9px; letter-spacing: 2px; text-transform: uppercase; color: #707088; margin-bottom: 8px; }}
  .ref-card img {{ height: 300px; width: auto; border-radius: 4px; cursor: pointer; }}
  .ref-card .meta {{ font-size: 9px; color: #404058; margin-top: 6px; }}
  .shots {{ padding: 16px 32px; }}
  .shot {{ background: #0e0e16; border: 1px solid #1a1a28; border-radius: 8px; margin-bottom: 20px; overflow: hidden; }}
  .shot-header {{ padding: 12px 16px; border-bottom: 1px solid #1a1a28; background: #101018; }}
  .shot-header h3 {{ font-size: 12px; color: #d0d0d8; }}
  .compare {{ display: grid; grid-template-columns: repeat({n_cols}, 1fr); gap: 1px; background: #1a1a28; }}
  .col {{ background: #0e0e16; padding: 10px; }}
  .col h4 {{ font-size: 9px; letter-spacing: 1.5px; text-transform: uppercase; margin-bottom: 6px; }}
  .col img {{ width: 100%; border-radius: 4px; border: 1px solid #1a1a28; cursor: pointer; }}
  .col img:hover {{ border-color: #00e5ff; }}
  .col .meta {{ font-size: 8px; color: #404058; margin-top: 4px; }}
  .error {{ color: #ef5350; font-size: 9px; padding: 16px; text-align: center; }}
  .lightbox {{ display: none; position: fixed; inset: 0; background: rgba(0,0,0,0.95); z-index: 100; align-items: center; justify-content: center; }}
  .lightbox.open {{ display: flex; }}
  .lightbox img {{ max-width: 95vw; max-height: 95vh; object-fit: contain; }}
  .lightbox .close {{ position: absolute; top: 16px; right: 20px; font-size: 28px; color: #505068; cursor: pointer; background: none; border: none; }}
</style>
</head>
<body>
<div class="header">
  <h1>Gemini Wardrobe Swap Test — EP{episode:03d}</h1>
  <div class="sub">{title_char} | {datetime.now().strftime('%Y-%m-%d %H:%M')} | {len(results)} shots</div>
  <div class="sub" style="margin-top:6px; color:#404058;">
    Testing 3 Gemini swap strategies: cropped clothing ref, text-only description, flat-lay product photo.
  </div>
</div>
"""
    parts = [head]

    # Reference images: shown once, from the first result that carries refs.
    parts.append('<div class="section"><h2>Reference Materials</h2><div class="refs">\n')
    refs = next((r["_refs"] for r in results if r.get("_refs")), {})
    for ref_name, ref_info in refs.items():
        if ref_info.get("filename"):
            parts.append(
                '<div class="ref-card">\n'
                f'  <h3>{escape(str(ref_name))}</h3>\n'
                f'  <img src="{escape(str(ref_info["filename"]))}" onclick="openLB(this.src)">\n'
                f'  <div class="meta">{escape(str(ref_info.get("note", "")))}</div>\n'
                '</div>\n'
            )
    parts.append('</div></div>\n')

    # One comparison row per shot: source column + one column per method.
    parts.append('<div class="shots">\n')
    for r in results:
        parts.append(
            '<div class="shot">\n'
            f'  <div class="shot-header"><h3>S{r["shot_id"]:02d}: {escape(str(r["name"]))}</h3></div>\n'
            '  <div class="compare">\n'
            '    <div class="col"><h4 style="color:#ffffff">Source (Turbo T2I)</h4>\n'
        )
        if r.get("source_filename"):
            parts.append(
                f'      <img src="{escape(str(r["source_filename"]))}" onclick="openLB(this.src)">\n'
            )
        parts.append('    </div>\n')

        for mk in METHOD_KEYS:
            m = METHODS[mk]
            mr = r.get(mk) or {}
            parts.append(f'    <div class="col"><h4 style="color:{m["color"]}">{m["short"]}</h4>\n')
            if mr.get("error"):
                # Truncate before escaping so an entity is never cut in half.
                parts.append(f'      <div class="error">{escape(str(mr["error"])[:100])}</div>\n')
            elif mr.get("filename"):
                parts.append(f'      <img src="{escape(str(mr["filename"]))}" onclick="openLB(this.src)">\n')
                parts.append(f'      <div class="meta">{mr.get("elapsed", 0):.1f}s</div>\n')
            else:
                parts.append('      <div class="error">No result</div>\n')
            parts.append('    </div>\n')

        parts.append('  </div>\n</div>\n')

    parts.append("""</div>
<div class="lightbox" id="lb" onclick="if(event.target===this)closeLB()">
  <button class="close" onclick="closeLB()">&times;</button>
  <img id="lb-img">
</div>
<script>
function openLB(src){document.getElementById('lb-img').src=src;document.getElementById('lb').classList.add('open');}
function closeLB(){document.getElementById('lb').classList.remove('open');}
document.addEventListener('keydown',e=>{if(e.key==='Escape')closeLB();});
</script>
</body></html>""")

    html_path = output_dir / f"gemini_swap_test_ep_{episode:03d}_{character}.html"
    # The page declares UTF-8, so write it as UTF-8 regardless of platform locale.
    html_path.write_text("".join(parts), encoding="utf-8")
    return html_path


def _load_json_or_exit(path: Path, missing_msg: Optional[str] = None):
    """Load and return the JSON document at ``path``, exiting on failure.

    If the file does not exist, prints ``missing_msg`` (or a generic
    "not found" message) and exits with status 1. If it is not valid JSON,
    prints the decode error and exits with status 1.
    """
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(missing_msg or f"ERROR: File not found: {path}")
        sys.exit(1)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in {path}: {e}")
        sys.exit(1)


def _run_swap(label: str, swap_fn, swap_args: tuple, out_path: Path) -> dict:
    """Run one swap method for one shot and return its result record.

    Calls ``swap_fn(*swap_args)``, downloads the first returned image to
    ``out_path`` and prints a one-line status prefixed with ``label``.

    Returns:
        ``{"filename": <out_path name>, "elapsed": <seconds>}`` on success,
        otherwise ``{"error": <message>}`` (API exception, no image in the
        response, or failed download).
    """
    print(f"    {label}...", end="", flush=True)
    t0 = time.time()
    try:
        res = swap_fn(*swap_args)
        elapsed = time.time() - t0
        url = get_image_url(res)
        if not url:
            print(" no image")
            return {"error": "no image returned"}
        # download_image reports failure via its return value (and prints it).
        if not download_image(url, out_path):
            return {"error": "download failed"}
    except Exception as e:
        print(f" ERROR: {e}")
        return {"error": str(e)}
    print(f" OK ({elapsed:.1f}s)")
    return {"filename": out_path.name, "elapsed": elapsed}


def main():
    """Run every selected shot through all three Gemini swap methods.

    Steps:
      1. Load the episode storyboard and the comprehensive-test results
         (source images per shot + hero URL).
      2. Prepare references: crop/upload the hero frame (method A) and
         generate a flat-lay image (method C).
      3. For each selected shot, run methods A, B and C and download outputs.
      4. Write the comparison HTML and a JSON record to
         ``storyboards/assets/ep_XXX/gemini_swap_test/``.
    """
    import shutil

    parser = argparse.ArgumentParser(description="Test Gemini wardrobe swap strategies")
    parser.add_argument("project_dir")
    parser.add_argument("-e", "--episode", type=int, required=True)
    parser.add_argument("-c", "--character", required=True)
    parser.add_argument("--shots", help="Comma-separated shot IDs (default: 3,5,6)")
    parser.add_argument("--hero-url", help="Hero frame CDN URL")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    # resolve() already anchors a relative path at the current directory.
    project_dir = Path(args.project_dir).resolve()
    if not project_dir.is_dir():
        print(f"ERROR: Project directory not found: {args.project_dir}")
        sys.exit(1)

    char = args.character.lower()
    ep_str = str(args.episode).zfill(3)
    episode_assets = project_dir / "storyboards" / "assets" / f"ep_{ep_str}"

    # The storyboard is only used to sanity-check that the character exists.
    sb_path = project_dir / "storyboards" / f"storyboard_ep_{ep_str}.json"
    storyboard = _load_json_or_exit(sb_path)
    if char not in (storyboard.get("characters") or {}):
        print(f"WARNING: Character '{char}' not found in {sb_path.name}")

    # NOTE: wardrobe text is hard-coded for this test; it is not read from the
    # storyboard, so it only makes sense for the character it was written for.
    wardrobe_desc = (
        f"Worn patched cargo pants with reinforced knees, salvage harness with metal "
        f"clips across chest, heavy-duty tool belt with salvage hook worn smooth from "
        f"years of use, layered thermal undershirt visible at collar, amber-glowing debt "
        f"counter wrist strap on left wrist with gunmetal housing and scratched protective "
        f"lens, patched rebreather mask hanging around neck, fingerless work gloves with "
        f"orange-stained cuticles visible. All clothing is industrial, patched, stained "
        f"with rust and grease — lower-deck salvager aesthetic."
    )

    # Source images come from the comprehensive test's turbo T2I results.
    comp_dir = episode_assets / "comprehensive_test"
    comp_json = comp_dir / f"comprehensive_test_ep_{ep_str}_{char}.json"
    comp_results = _load_json_or_exit(
        comp_json, "ERROR: Run ab_test_comprehensive.py first to generate source images"
    )

    hero_url = args.hero_url or (comp_results.get("hero_info") or {}).get("url")
    if not hero_url:
        print("ERROR: No hero URL found. Provide --hero-url")
        sys.exit(1)

    shot_ids = [3, 5, 6]
    if args.shots:
        try:
            shot_ids = [int(s) for s in args.shots.split(",") if s.strip()]
        except ValueError:
            parser.error(f"--shots must be comma-separated integers, got {args.shots!r}")
    wanted = set(shot_ids)

    source_urls = {}
    for r in comp_results.get("results", []):
        sid = r.get("shot_id")
        if sid not in wanted:
            continue
        a_data = r.get("a_turbo_t2i") or {}
        if a_data.get("cdn_url"):
            source_urls[sid] = {
                "cdn_url": a_data["cdn_url"],
                "filename": a_data.get("filename"),
                "name": r.get("name", ""),
            }

    if not source_urls:
        print("ERROR: No source images found in comprehensive test results")
        sys.exit(1)

    output_dir = episode_assets / "gemini_swap_test"
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"{'=' * 65}")
    print(f"  GEMINI WARDROBE SWAP TEST — EP{ep_str} — {char.upper()}")
    print(f"{'=' * 65}")
    print(f"  Shots: {list(source_urls.keys())}")
    print("  Methods: Cropped ref, Text-only, Flat-lay")
    print(f"  Hero URL: {hero_url[:50]}...")
    print(f"{'=' * 65}")

    if args.dry_run:
        n_swaps = len(source_urls) * len(METHOD_KEYS)
        print(f"\n  Would test {len(source_urls)} shots x {len(METHOD_KEYS)} methods = {n_swaps} Gemini calls")
        print("  + 1 flat-lay generation (Gemini) + 1 hero crop/upload (no Gemini call)")
        # Only the swaps and the flat-lay hit Gemini; the crop is local + upload.
        print(f"  Est cost: ~${(n_swaps + 1) * 0.04:.2f}")
        sys.exit(0)

    # ── Step 1: Prepare references ────────────────────────────────────

    print("\n  Cropping hero to clothing region...", end="", flush=True)
    crop_path = output_dir / f"hero_cropped_{char}.png"
    try:
        cropped_url = crop_hero_to_clothing(hero_url, crop_path)
        print(" OK")
    except Exception as e:
        # Method A needs the crop; methods B and C can still run without it.
        print(f" ERROR: {e}")
        cropped_url = None

    print("  Generating flat-lay of wardrobe...", end="", flush=True)
    flat_url = None
    flat_saved = False
    t0 = time.time()
    try:
        flat_url = get_image_url(generate_flat_lay_generate(wardrobe_desc))
        flat_elapsed = time.time() - t0
        if flat_url:
            flat_saved = download_image(flat_url, output_dir / f"flat_lay_{char}.png")
            status = "OK" if flat_saved else "generated, local copy not saved"
            print(f" {status} ({flat_elapsed:.1f}s)")
        else:
            print(" no image returned")
    except Exception as e:
        print(f" ERROR: {e}")
        flat_url = None

    # Copy source images next to the report so the HTML can reference them.
    for info in source_urls.values():
        fname = info["filename"]
        if not fname:
            continue
        src = comp_dir / fname
        dst = output_dir / fname
        if src.exists() and not dst.exists():
            shutil.copy2(src, dst)

    refs_info = {
        "Cropped Hero (no face)": {
            "filename": crop_path.name if cropped_url else None,
            "note": "Top 20% and bottom 10% removed — clothing only",
        },
        "Flat-Lay Product Photo": {
            "filename": f"flat_lay_{char}.png" if flat_saved else None,
            "note": "Generated by Gemini from text description",
        },
    }

    # ── Step 2: Test each shot x 3 methods ────────────────────────────

    all_results = []
    for sid, info in sorted(source_urls.items()):
        print(f"\n[S{sid:02d}: {info['name']}]")
        source_cdn = info["cdn_url"]

        entry = {
            "shot_id": sid,
            "name": info["name"],
            "source_filename": info["filename"],
            "_refs": refs_info,
        }

        if cropped_url:
            entry["a_cropped_ref"] = _run_swap(
                "A) Cropped ref", gemini_swap_cropped, (source_cdn, cropped_url),
                output_dir / f"S{sid:02d}_a_cropped.png",
            )
        else:
            print("    A) Cropped ref... SKIP (hero crop failed)")
            entry["a_cropped_ref"] = {"error": "hero crop failed"}

        entry["b_text_only"] = _run_swap(
            "B) Text-only", gemini_swap_text_only, (source_cdn, wardrobe_desc),
            output_dir / f"S{sid:02d}_b_text.png",
        )

        if flat_url:
            entry["c_flat_lay"] = _run_swap(
                "C) Flat-lay ref", gemini_swap_flat_lay, (source_cdn, flat_url),
                output_dir / f"S{sid:02d}_c_flatlay.png",
            )
        else:
            print("    C) Flat-lay ref... SKIP (flat-lay generation failed)")
            entry["c_flat_lay"] = {"error": "flat-lay generation failed"}

        all_results.append(entry)

    # ── Step 3: Build HTML + JSON ─────────────────────────────────────

    html_path = build_html(all_results, hero_url, output_dir, args.episode, char)

    json_path = output_dir / f"gemini_swap_test_ep_{ep_str}_{char}.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({
            "episode": args.episode,
            "character": char,
            "hero_url": hero_url,
            "cropped_ref_url": cropped_url,
            "flat_lay_url": flat_url,
            "wardrobe_desc": wardrobe_desc,
            "generated_at": datetime.now().isoformat(),
            "results": all_results,
        }, f, indent=2)

    print(f"\n{'=' * 65}")
    print("  RESULTS")
    print(f"{'=' * 65}")
    for mk in METHOD_KEYS:
        m = METHODS[mk]
        ok = sum(1 for r in all_results if r.get(mk, {}).get("filename"))
        fail = sum(1 for r in all_results if r.get(mk, {}).get("error"))
        print(f"  {m['label']:30s}  {ok} OK  {fail} fail")
    print(f"\n  HTML: {html_path}")
    print(f"{'=' * 65}")


if __name__ == "__main__":
    main()
