#!/usr/bin/env python3
"""
upscaler_shootout.py — Compare upscaler models on a single input image.

Compares: SeedVR2, Crystal Upscaler, Creative Upscaler, and Topaz (plus tuned variants).
Saves each result and a side-by-side HTML comparison page to the output directory.

Usage:
    python3 upscaler_shootout.py <input_image> [--output-dir <dir>] [--upscalers <key1,key2,...>]

Requires the FAL_KEY environment variable to be set.
"""

import argparse
import os
import shutil
import sys
import time
import urllib.request
from pathlib import Path

try:
    import fal_client
except ImportError:
    print("ERROR: fal-client not installed. Run: pip install fal-client", file=sys.stderr)
    sys.exit(1)


# Registry of upscaler configurations, keyed by the short name accepted by
# ``--upscalers``. Each entry holds:
#   name      - human-readable label used in console output and the HTML page
#   endpoint  - endpoint id passed to ``fal_client.run``
#   cost_note - informational pricing text shown on the HTML page
#   args      - request arguments for the endpoint; ``image_url`` is added at call time
_PORTRAIT_PROMPT = " photorealistic cinematic portrait, detailed skin texture with visible pores, sharp iris detail"

UPSCALERS = {
    "seedvr2": dict(
        name="SeedVR2",
        endpoint="fal-ai/seedvr/upscale/image",
        cost_note="$0.001/MP",
        args=dict(upscale_factor=2, output_format="png"),
    ),
    "seedvr2_low_noise": dict(
        name="SeedVR2 (low noise)",
        endpoint="fal-ai/seedvr/upscale/image",
        cost_note="$0.001/MP",
        args=dict(upscale_factor=2, noise_scale=0.02, output_format="png"),
    ),
    "crystal": dict(
        name="Crystal Upscaler",
        endpoint="clarityai/crystal-upscaler",
        cost_note="$0.016/MP",
        args=dict(scale_factor=2, creativity=1),
    ),
    "creative": dict(
        name="Creative Upscaler",
        endpoint="fal-ai/creative-upscaler",
        cost_note="~$0.05-0.10",
        args=dict(
            scale=2,
            creativity=0.3,
            detail=2.0,
            shape_preservation=1.0,
            prompt_suffix=_PORTRAIT_PROMPT,
            num_inference_steps=30,
        ),
    ),
    # Same endpoint as "creative", pushed toward more invention and less
    # shape preservation, with an extra stylistic cue in the prompt.
    "creative_high": dict(
        name="Creative Upscaler (high creativity)",
        endpoint="fal-ai/creative-upscaler",
        cost_note="~$0.05-0.10",
        args=dict(
            scale=2,
            creativity=0.6,
            detail=3.0,
            shape_preservation=0.5,
            prompt_suffix=_PORTRAIT_PROMPT + ", electric blue glowing eyes",
            num_inference_steps=30,
        ),
    ),
    "topaz": dict(
        name="Topaz",
        endpoint="fal-ai/topaz/upscale/image",
        cost_note="~$0.01-0.02",
        args=dict(upscale_factor=2),
    ),
}


def _extract_image_url(result):
    """Return the first output-image URL found in an endpoint response, or None.

    Endpoints report their output under different keys; they are checked in
    the order ``image``, ``output``, ``images``, ``url``. A value may be a bare
    URL string, a dict with a ``"url"`` entry, or a list whose first element
    is either of those. A key that is present but unusable (None, empty list,
    dict without ``"url"``) falls through to the next candidate.
    """
    if not isinstance(result, dict):
        return None
    for field in ("image", "output", "images", "url"):
        value = result.get(field)
        if isinstance(value, (list, tuple)):
            value = value[0] if value else None
        if isinstance(value, dict):
            value = value.get("url")
        if isinstance(value, str) and value:
            return value
    return None


def run_upscaler(key, config, image_url, output_path):
    """Run one upscaler endpoint on ``image_url`` and download its output.

    Args:
        key: Registry key of the upscaler. It is not used here; it is kept so the
            call signature stays stable for callers.
        config: Registry entry providing ``name``, ``endpoint`` and ``args``.
        image_url: URL of the uploaded input image, sent as ``image_url``.
        output_path: ``pathlib.Path`` the result is written to; missing parent
            directories are created.

    Returns:
        dict with ``success`` (bool) and ``elapsed`` (seconds), plus ``path`` on
        success or ``error`` (truncated message) on failure. Errors are
        reported rather than raised, so one failing model does not abort the
        whole shootout.
    """
    # Copy so the shared registry entry is never mutated.
    args = {**config["args"], "image_url": image_url}

    print(f"  [{config['name']}] Running...", end=" ", flush=True)
    t0 = time.time()
    try:
        result = fal_client.run(config["endpoint"], arguments=args)
        elapsed = time.time() - t0

        img_url = _extract_image_url(result)
        if not img_url:
            is_dict = isinstance(result, dict)
            print(f"FAILED: No image in response. Keys: {list(result.keys()) if is_dict else type(result)}")
            return {
                "success": False,
                "elapsed": elapsed,
                "error": f"No image URL found. Response keys: {list(result.keys()) if is_dict else 'not dict'}",
            }

        output_path.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(img_url, str(output_path))
        print(f"OK ({elapsed:.1f}s)")
        return {"success": True, "elapsed": elapsed, "path": str(output_path)}
    except Exception as e:  # per-model boundary: report the failure and let the run continue
        elapsed = time.time() - t0
        # Include the exception type: str() alone can be empty or cryptic (e.g. KeyError -> "'url'").
        message = f"{type(e).__name__}: {e}"
        print(f"ERROR ({elapsed:.1f}s): {message[:100]}")
        return {"success": False, "elapsed": elapsed, "error": message[:200]}


def generate_html(results, input_path, output_dir):
    """Write ``comparison.html`` showing the original image beside every result.

    The input image is copied into ``output_dir`` as ``00_original.png`` so the
    page only references files that sit next to it. Successful runs become
    clickable cards that open a lightbox with keyboard navigation. Failed runs
    become dimmed cards carrying the first 80 characters of their error message.
    All interpolated text is HTML-escaped.

    Args:
        results: Mapping of ``UPSCALERS`` key -> result dict as returned by
            ``run_upscaler`` (``success``, ``elapsed``, and ``path`` or ``error``).
        input_path: Path of the source image.
        output_dir: ``pathlib.Path`` directory that receives the copy and the page.

    Returns:
        ``pathlib.Path`` of the written HTML file.
    """
    from html import escape

    input_name = Path(input_path).name

    # Copy the source beside the page so every <img> uses a relative path.
    shutil.copy2(input_path, output_dir / "00_original.png")
    cards = ["""
        <div class="card">
            <h3>Original (NBP output)</h3>
            <p class="meta">Source image — no upscaling</p>
            <img src="00_original.png" alt="Original">
        </div>
    """]

    for key, res in results.items():
        name = escape(UPSCALERS[key]["name"])
        if not res["success"]:
            # Truncate before escaping so an entity such as &amp; is never cut in half.
            error = escape(str(res.get("error") or "unknown")[:80])
            cards.append(f"""
                <div class="card failed">
                    <h3>{name}</h3>
                    <p class="meta">FAILED: {error}</p>
                </div>
            """)
            continue

        fname = escape(Path(res["path"]).name)
        cost = escape(UPSCALERS[key]["cost_note"])
        cards.append(f"""
            <div class="card">
                <h3>{name}</h3>
                <p class="meta">{res['elapsed']:.1f}s | {cost}</p>
                <img src="{fname}" alt="{escape(key)}">
            </div>
        """)

    title = escape(input_name)
    page = f"""<!DOCTYPE html>
<html><head><meta charset="utf-8">
<title>Upscaler Shootout — {title}</title>
<style>
    * {{ margin: 0; padding: 0; box-sizing: border-box; }}
    body {{ background: #111; color: #eee; font-family: -apple-system, sans-serif; padding: 20px; }}
    h1 {{ margin-bottom: 10px; font-size: 1.4em; }}
    .subtitle {{ color: #888; margin-bottom: 20px; font-size: 0.9em; }}
    .grid {{ display: grid; grid-template-columns: repeat(auto-fill, minmax(400px, 1fr)); gap: 16px; }}
    .card {{ background: #1a1a1a; border-radius: 8px; padding: 12px; cursor: pointer; transition: outline 0.15s; }}
    .card:hover {{ outline: 2px solid #555; }}
    .card.active-card {{ outline: 2px solid #4af; }}
    .card.failed {{ opacity: 0.5; cursor: default; }}
    .card h3 {{ font-size: 1em; margin-bottom: 4px; }}
    .card .meta {{ color: #888; font-size: 0.8em; margin-bottom: 8px; }}
    .card img {{ width: 100%; border-radius: 4px; cursor: zoom-in; }}
    .lightbox {{ display: none; position: fixed; top: 0; left: 0; width: 100vw; height: 100vh;
                 background: rgba(0,0,0,0.95); z-index: 1000;
                 justify-content: center; align-items: center; flex-direction: column; }}
    .lightbox.active {{ display: flex; }}
    .lightbox img {{ max-width: 92vw; max-height: 85vh; object-fit: contain; cursor: default; }}
    .lb-label {{ color: #fff; font-size: 1.1em; font-weight: 600; margin-bottom: 10px; text-align: center; }}
    .lb-counter {{ color: #666; font-size: 0.8em; margin-top: 8px; }}
    .lb-nav {{ position: absolute; top: 50%; transform: translateY(-50%); font-size: 3em;
               color: #555; cursor: pointer; user-select: none; padding: 20px; transition: color 0.15s; }}
    .lb-nav:hover {{ color: #fff; }}
    .lb-prev {{ left: 10px; }}
    .lb-next {{ right: 10px; }}
    .lb-close {{ position: absolute; top: 15px; right: 25px; font-size: 2em;
                 color: #555; cursor: pointer; transition: color 0.15s; }}
    .lb-close:hover {{ color: #fff; }}
    .hint {{ color: #555; font-size: 0.75em; margin-top: 10px; }}
</style>
</head><body>
<h1>Upscaler Shootout</h1>
<p class="subtitle">Input: {title} | Click any image to zoom | Arrow keys to navigate</p>
<div class="grid">
    {''.join(cards)}
</div>
<div class="lightbox" id="lb">
    <div class="lb-close" onclick="closeLightbox()">&times;</div>
    <div class="lb-nav lb-prev" onclick="navigate(-1)">&lsaquo;</div>
    <div class="lb-nav lb-next" onclick="navigate(1)">&rsaquo;</div>
    <div class="lb-label" id="lb-label"></div>
    <img id="lb-img" src="">
    <div class="lb-counter" id="lb-counter"></div>
</div>
<p class="hint">Click image to zoom &middot; Arrow keys or click arrows to navigate &middot; Esc to close</p>
<script>
const images = [];
const cards = document.querySelectorAll('.card:not(.failed)');
cards.forEach((card, i) => {{
    const img = card.querySelector('img');
    if (!img) return;
    // Capture this image's index now; reading images.length inside the
    // handler would be evaluated at click time and always open the last image.
    const idx = images.length;
    images.push({{
        src: img.src,
        label: card.querySelector('h3').textContent,
        meta: card.querySelector('.meta').textContent,
        cardIndex: i,
    }});
    img.onclick = () => openLightbox(idx);
}});

let currentIndex = 0;

function openLightbox(idx) {{
    currentIndex = idx;
    updateLightbox();
    document.getElementById('lb').classList.add('active');
    highlightCard();
}}

function closeLightbox() {{
    document.getElementById('lb').classList.remove('active');
    cards.forEach(c => c.classList.remove('active-card'));
}}

function navigate(dir) {{
    currentIndex = (currentIndex + dir + images.length) % images.length;
    updateLightbox();
    highlightCard();
}}

function updateLightbox() {{
    const item = images[currentIndex];
    document.getElementById('lb-img').src = item.src;
    document.getElementById('lb-label').textContent = item.label;
    document.getElementById('lb-counter').textContent =
        (currentIndex + 1) + ' / ' + images.length + '  ·  ' + item.meta;
}}

function highlightCard() {{
    cards.forEach(c => c.classList.remove('active-card'));
    const ci = images[currentIndex].cardIndex;
    if (cards[ci]) cards[ci].classList.add('active-card');
}}

document.addEventListener('keydown', (e) => {{
    const lb = document.getElementById('lb');
    if (!lb.classList.contains('active')) return;
    if (e.key === 'ArrowRight' || e.key === 'ArrowDown') {{ e.preventDefault(); navigate(1); }}
    else if (e.key === 'ArrowLeft' || e.key === 'ArrowUp') {{ e.preventDefault(); navigate(-1); }}
    else if (e.key === 'Escape') {{ closeLightbox(); }}
}});

document.getElementById('lb').addEventListener('click', (e) => {{
    if (e.target === document.getElementById('lb')) closeLightbox();
}});
</script>
</body></html>"""

    html_path = output_dir / "comparison.html"
    # Write UTF-8 explicitly to match <meta charset="utf-8">; the platform default may differ.
    html_path.write_text(page, encoding="utf-8")
    return html_path


def main():
    """Command-line entry point: run the selected upscalers on one image and build the report.

    Exits with status 1 in any of these cases:
      - the input file is missing;
      - FAL_KEY is not set;
      - no known upscaler key was selected;
      - the input upload fails.

    Individual upscaler failures do not abort the run. They are listed in the
    summary and shown as failed cards on the HTML page.
    """
    import mimetypes

    parser = argparse.ArgumentParser(description="Upscaler Shootout — compare upscaler models")
    parser.add_argument("input", help="Input image path")
    parser.add_argument("--output-dir", default=None, help="Output directory (default: next to input)")
    parser.add_argument(
        "--upscalers",
        default=None,
        help="Comma-separated upscaler keys (default: all). Available: " + ", ".join(UPSCALERS),
    )
    args = parser.parse_args()

    input_path = Path(args.input).resolve()
    if not input_path.is_file():
        print(f"ERROR: Input not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    # Validate credentials and the model selection before creating output or
    # uploading anything, so a misconfigured run fails fast and leaves nothing behind.
    if not os.environ.get("FAL_KEY"):
        print("ERROR: FAL_KEY environment variable not set.", file=sys.stderr)
        sys.exit(1)

    if args.upscalers:
        requested = [k.strip() for k in args.upscalers.split(",") if k.strip()]
    else:
        requested = list(UPSCALERS)
    keys = []
    for key in dict.fromkeys(requested):  # de-duplicate while keeping the user's order
        if key in UPSCALERS:
            keys.append(key)
        else:
            print(f"  WARNING: Unknown upscaler '{key}', skipping")
    if not keys:
        print(f"ERROR: No known upscalers selected. Available: {', '.join(UPSCALERS)}", file=sys.stderr)
        sys.exit(1)

    if args.output_dir:
        output_dir = Path(args.output_dir).resolve()
    else:
        output_dir = input_path.parent / "upscaler_shootout"
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nInput: {input_path.name}")
    print(f"Output: {output_dir}")
    print(f"Upscalers: {len(keys)}\n")

    # Upload the input once; every endpoint receives the same hosted URL.
    # Label it with its real type and extension; fall back to PNG when the
    # type cannot be guessed or is not an image type.
    data = input_path.read_bytes()
    suffix = input_path.suffix.lower()
    mime, _ = mimetypes.guess_type(input_path.name)
    if not mime or not mime.startswith("image/"):
        mime = "image/png"
    file_name = f"input{suffix}" if suffix else "input.png"

    print("Uploading input image...", end=" ", flush=True)
    try:
        image_url = fal_client.upload(data, mime, file_name=file_name)
    except Exception as e:
        print("FAILED")
        print(f"ERROR: Upload failed: {e}", file=sys.stderr)
        sys.exit(1)
    print("OK\n")

    # Run each upscaler; results keep the selection order for the page and summary.
    results = {}
    for i, key in enumerate(keys):
        if i:
            time.sleep(1)  # pause between consecutive endpoint calls (not after the last one)
        results[key] = run_upscaler(key, UPSCALERS[key], image_url, output_dir / f"{key}.png")

    html_path = generate_html(results, input_path, output_dir)

    # Summary
    print(f"\n{'='*50}")
    print("RESULTS")
    print(f"{'='*50}")
    succeeded = sum(1 for r in results.values() if r["success"])
    print(f"  Succeeded: {succeeded}/{len(results)}")
    for key, res in results.items():
        status = "OK" if res["success"] else "FAILED"
        elapsed = f"{res['elapsed']:.1f}s" if "elapsed" in res else "—"
        print(f"  {UPSCALERS[key]['name']:30s} {status:6s} {elapsed}")
    print(f"\n  Comparison: {html_path}")
    print(f"  Open: open \"{html_path}\"")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()
