#!/usr/bin/env python3
"""
Pass 3 Alternatives Test — Expression Change + SeedVR2 vs NBP

Tests fal.ai Expression Change and SeedVR2 as potential replacements for
NBP/Gemini 3 Pro in the three-pass pipeline. Runs on existing Pass 2 outputs
and generates a side-by-side comparison HTML.

Usage:
    python3 pass3_alternatives_test.py
"""

import sys
import time
import json
import shutil
from pathlib import Path

# Ensure fal_client is available
try:
    import fal_client
except ImportError:
    print("ERROR: fal_client not installed. Run: pip install fal-client")
    sys.exit(1)

# --- Configuration ---

# Directory holding the JINX shootout runs. It is resolved from this script's
# location: the grandparent of the script's folder, then down into
# leviathan/visual/lora_candidates/JINX/shootout.
SHOOTOUT_DIR = (
    Path(__file__).resolve().parent.parent.parent
    .joinpath("leviathan", "visual", "lora_candidates", "JINX", "shootout")
)

# Test cases re-used from earlier (pre-dual-reference) shootout runs.
# name       -> short id used as the prefix of every output file
# expression -> prompt handed to Expression Change
# run_dir    -> sub-directory of SHOOTOUT_DIR holding the Pass 2 / NBP images
TEST_CASES = [
    dict(
        name="front_exhausted",
        expression="exhausted — heavy-lidded eyes, slight frown, drained hollow gaze",
        run_dir="front_exhausted — heavy-lidded eyes, slight frown, drained hollow gaze_20260214_142921",
    ),
    dict(
        name="front_focused",
        expression="focused — narrowed eyes, set jaw, intent forward stare",
        run_dir="front_focused — narrowed eyes, set jaw, intent forward stare_20260214_142531",
    ),
]

# Everything this script writes (copies, downloads, comparison.html) goes here.
OUTPUT_DIR = SHOOTOUT_DIR.joinpath("pass3_alternatives_test")


def upload_to_fal(image_path: Path) -> str:
    """Upload a local image to fal.ai and return the hosted URL.

    The bytes are sent under a fixed ASCII file name ("input" + extension).
    The original author did this to avoid encoding errors from the real path.
    The extension is derived from the guessed MIME type, so the upload name
    and the declared content type always agree.

    Args:
        image_path: local image file to upload.

    Returns:
        The URL returned by ``fal_client.upload``.
    """
    import mimetypes

    data = image_path.read_bytes()
    mime = mimetypes.guess_type(str(image_path))[0] or "image/png"
    # Previously this was always "input.png", even for e.g. JPEG input; keep the
    # extension consistent with `mime` (falls back to .png when unknown).
    safe_name = "input" + (mimetypes.guess_extension(mime) or ".png")
    url = fal_client.upload(data, mime, file_name=safe_name)
    print(f"  Uploaded: {image_path.name} → {url[:80]}...")
    return url


def run_expression_change(image_url: str, expression: str) -> dict:
    """Apply fal.ai's Expression Change edit to the image at *image_url*.

    Args:
        image_url: URL of the source image (e.g. from ``upload_to_fal``).
        expression: text prompt describing the target expression.

    Returns:
        ``{"result": <raw fal.ai response>, "time": <seconds spent in the call>}``.
    """
    print(f"  Running Expression Change: '{expression[:50]}...'")
    started = time.time()
    response = fal_client.run(
        "fal-ai/image-editing/expression-change",
        arguments=dict(
            image_url=image_url,
            prompt=expression,
            guidance_scale=3.0,  # author's note: lower = less pose drift
            num_inference_steps=30,
        ),
    )
    duration = time.time() - started
    print(f"  Expression Change done in {duration:.1f}s")
    return dict(result=response, time=duration)


def run_seedvr2(image_url: str) -> dict:
    """Send the image at *image_url* through fal.ai's SeedVR2 image upscaler.

    Returns:
        ``{"result": <raw fal.ai response>, "time": <seconds spent in the call>}``.
    """
    print("  Running SeedVR2 upscale...")
    started = time.time()
    response = fal_client.run("fal-ai/seedvr/upscale/image", arguments=dict(image_url=image_url))
    duration = time.time() - started
    print(f"  SeedVR2 done in {duration:.1f}s")
    return dict(result=response, time=duration)


def get_image_url(result: dict) -> str:
    """Extract the output image URL from a fal.ai result.

    Different endpoints use different response shapes. They are checked in
    this order:

    1. ``result["image"]["url"]``
    2. ``result["images"][0]["url"]``
    3. ``result["output"]`` — either a dict with ``"url"`` or the URL itself.

    Raises:
        KeyError: none of the keys above is present, or ``"images"`` is an
            empty list (previously this surfaced as a bare ``IndexError``).
    """
    if "image" in result:
        return result["image"]["url"]
    if "images" in result:
        images = result["images"]
        if not images:
            raise KeyError("fal.ai result contains an empty 'images' list")
        return images[0]["url"]
    if "output" in result:
        output = result["output"]
        return output["url"] if isinstance(output, dict) else output
    raise KeyError(f"Cannot find image URL in result: {list(result.keys())}")


def run_expression_then_seedvr2(image_url: str, expression: str) -> dict:
    """Run Expression Change, then SeedVR2 on the expression-changed image (chained).

    Args:
        image_url: URL of the source image.
        expression: prompt passed to Expression Change.

    Returns:
        dict with:
        - ``expr_result``: the ``run_expression_change`` record ``{"result", "time"}``
        - ``seed_result``: the ``run_seedvr2`` record ``{"result", "time"}``
        - ``total_time``: sum of both call durations in seconds
    """
    expr_result = run_expression_change(image_url, expression)
    expr_image_url = get_image_url(expr_result["result"])

    # Reuse run_seedvr2 instead of duplicating the endpoint/arguments here, so
    # the standalone and chained SeedVR2 runs can never drift apart.
    print("  Chaining SeedVR2 onto the expression-changed output...")
    seed_result = run_seedvr2(expr_image_url)

    return {
        "expr_result": expr_result,
        "seed_result": seed_result,
        "total_time": expr_result["time"] + seed_result["time"],
    }


def download_image(url: str, output_path: Path, timeout: float = 120.0):
    """Download *url* into *output_path*.

    Uses ``urllib.request.urlopen`` with a timeout rather than the legacy
    ``urlretrieve``, which has none, so a stalled connection fails instead of
    hanging the whole run. Data is streamed into a ``<name>.part`` sibling and
    renamed on success, so an interrupted download never leaves a truncated
    image under the final name.

    Args:
        url: source URL (any scheme ``urlopen`` supports).
        output_path: destination file; replaced if it already exists.
        timeout: socket timeout in seconds for the connection.

    Raises:
        OSError (including ``urllib.error.URLError``): the download failed.
    """
    import urllib.request

    tmp_path = output_path.with_name(output_path.name + ".part")
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response, open(tmp_path, "wb") as fh:
            shutil.copyfileobj(response, fh)
        tmp_path.replace(output_path)
    finally:
        # No-op after a successful rename; removes partial data on failure.
        tmp_path.unlink(missing_ok=True)
    print(f"  Saved: {output_path.name}")


def _comparison_column(src: str, title: str, caption: str, time_text: str = "") -> str:
    """Render one clickable image column of the comparison grid (src/title/time are HTML-escaped)."""
    from html import escape

    time_html = f'<br><span class="time">{escape(time_text)}</span>' if time_text else ""
    # `caption` is a fixed string chosen by the caller and may contain markup (<br>).
    return (
        f'<div class="col"><img src="{escape(src)}" onclick="openLightbox(this)">'
        f'<div class="label"><strong>{escape(title)}</strong><br>{caption}{time_html}</div></div>\n'
    )


def generate_comparison_html(test_results: list, output_dir: Path):
    """Generate a side-by-side comparison page at ``output_dir / "comparison.html"``.

    The page declares UTF-8 and is written as UTF-8, because titles and
    expressions contain non-ASCII characters such as em dashes. All
    interpolated test-case text is HTML-escaped.

    Args:
        test_results: one dict per test case with keys ``name``, ``expression``,
            ``pass2_file``, ``nbp_file``, ``nbp_time``, ``expr_file``,
            ``expr_time``, ``seedvr2_file``, ``seedvr2_time``, ``chained_file``
            and ``chained_time``. File values are used verbatim as ``<img src>``,
            i.e. relative to *output_dir*.
        output_dir: directory the HTML file is written into.

    Returns:
        Path of the written HTML file.
    """
    from html import escape

    parts = ["""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Pass 3 Alternatives Test — Expression Change + SeedVR2 vs NBP</title>
<style>
    body { background: #1a1a1a; color: #e0e0e0; font-family: -apple-system, sans-serif; margin: 20px; }
    h1 { color: #fff; border-bottom: 2px solid #444; padding-bottom: 10px; }
    h2 { color: #aaa; margin-top: 40px; }
    .test-case { margin-bottom: 60px; }
    .row { display: flex; gap: 12px; margin-bottom: 20px; flex-wrap: wrap; }
    .col { flex: 1; min-width: 280px; max-width: 520px; }
    .col img { width: 100%; border: 2px solid #333; border-radius: 4px; cursor: pointer; }
    .col img:hover { border-color: #0af; }
    .col img.selected { border-color: #0f0; border-width: 3px; }
    .label { text-align: center; font-size: 13px; color: #888; margin-top: 4px; }
    .label strong { color: #ccc; }
    .time { color: #666; font-size: 11px; }
    .note { color: #999; font-size: 14px; margin: 10px 0; padding: 10px; background: #222; border-radius: 4px; }
    .section-label { font-size: 16px; color: #0af; margin: 20px 0 8px 0; font-weight: bold; }
    /* Lightbox */
    .lightbox { display: none; position: fixed; top: 0; left: 0; width: 100vw; height: 100vh;
                background: rgba(0,0,0,0.95); z-index: 1000; justify-content: center; align-items: center; }
    .lightbox.active { display: flex; }
    .lightbox img { max-width: 95vw; max-height: 95vh; object-fit: contain; }
    .lightbox-close { position: fixed; top: 20px; right: 30px; color: #fff; font-size: 30px; cursor: pointer; z-index: 1001; }
</style>
</head>
<body>
<h1>Pass 3 Alternatives Test</h1>
<p class="note">Comparing potential Pass 3 replacements on Pass 2 (Qwen Edit) outputs.<br>
<strong>Goal:</strong> Find a non-generative quality enhancer + expression model that won't override angle like NBP does.<br>
Click any image to enlarge. Look for: expression clarity, angle preservation, skin texture, eye sharpness, identity consistency.</p>
"""]

    for tc in test_results:
        heading = f'{tc["name"].replace("_", " ").title()}: {tc["expression"]}'
        parts.append('<div class="test-case">\n')
        parts.append(f"<h2>{escape(heading)}</h2>\n")

        # Row 1: the Pass 2 image every alternative starts from.
        parts.append('<div class="section-label">Input (Pass 2 — Qwen Edit output)</div>\n')
        parts.append('<div class="row">\n')
        parts.append(_comparison_column(tc["pass2_file"], "Pass 2 (Qwen Edit)", "The input to all alternatives"))
        parts.append("</div>\n")

        # Row 2: every candidate output side by side.
        parts.append('<div class="section-label">Outputs — Side by Side Comparison</div>\n')
        parts.append('<div class="row">\n')
        parts.append(_comparison_column(
            tc["nbp_file"], "Old NBP (Gemini 3 Pro)", "Pre-dual-ref, single Pass 2 input", tc["nbp_time"]))
        parts.append(_comparison_column(
            tc["expr_file"], "Expression Change", "fal.ai FLUX-based expression edit", tc["expr_time"]))
        parts.append(_comparison_column(
            tc["seedvr2_file"], "SeedVR2", "Non-generative quality upscale only", tc["seedvr2_time"]))
        parts.append(_comparison_column(
            tc["chained_file"], "Expression → SeedVR2", "Expression Change then SeedVR2 quality", tc["chained_time"]))
        parts.append("</div>\n")  # row
        parts.append("</div>\n")  # test-case

    parts.append("""
<div id="lightbox" class="lightbox" onclick="closeLightbox()">
    <span class="lightbox-close" onclick="closeLightbox()">&times;</span>
    <img id="lightbox-img" src="">
</div>
<script>
function openLightbox(el) {
    document.getElementById('lightbox-img').src = el.src;
    document.getElementById('lightbox').classList.add('active');
}
function closeLightbox() {
    document.getElementById('lightbox').classList.remove('active');
}
document.addEventListener('keydown', e => { if (e.key === 'Escape') closeLightbox(); });
</script>
</body>
</html>""")

    html_path = output_dir / "comparison.html"
    # Explicit UTF-8 so the em dashes survive regardless of the platform default encoding.
    html_path.write_text("".join(parts), encoding="utf-8")
    print(f"\nComparison HTML: {html_path}")
    return html_path


def _read_nbp_time(run_dir: Path) -> str:
    """Return the NBP duration from ``run_dir/results.json`` as ``"<secs>s"``, or ``"N/A"``.

    If several ``engines`` entries have ``engine == "nbp"``, the last one wins.
    """
    results_json = run_dir / "results.json"
    nbp_time_str = "N/A"
    if results_json.exists():
        rdata = json.loads(results_json.read_text(encoding="utf-8"))
        for eng in rdata.get("engines", []):
            # `is not None` so a recorded 0.0 is still reported rather than "N/A".
            if eng.get("engine") == "nbp" and eng.get("time") is not None:
                nbp_time_str = f"{eng['time']:.1f}s"
    return nbp_time_str


def _image_url_or_debug(result: dict, label: str) -> str:
    """Return ``get_image_url(result)``; on failure dump the payload for debugging, then re-raise."""
    try:
        return get_image_url(result)
    except (KeyError, IndexError, TypeError):
        print(f"  DEBUG: {label} result keys: {list(result.keys())}")
        print(f"  DEBUG: Full result: {json.dumps(result, indent=2, default=str)[:500]}")
        raise


def _run_test_case(tc: dict):
    """Run every Pass 3 alternative for one TEST_CASES entry.

    Returns:
        The per-case record expected by ``generate_comparison_html``, or
        ``None`` when the run directory or its Pass 2 / NBP images are
        missing (the problem has already been printed).
    """
    run_dir = SHOOTOUT_DIR / tc["run_dir"]
    if not run_dir.exists():
        print(f"  ERROR: Run directory not found: {run_dir}")
        return None

    # sorted() makes the pick deterministic if several files match.
    pass2_files = sorted(run_dir.glob("pass2_*.png"))
    nbp_files = sorted(run_dir.glob("pass3_nbp_*.png"))
    if not pass2_files:
        print(f"  ERROR: No Pass 2 file found in {run_dir}")
        return None
    if not nbp_files:
        print(f"  ERROR: No NBP file found in {run_dir}")
        return None
    pass2_path, nbp_path = pass2_files[0], nbp_files[0]

    name = tc["name"]
    pass2_name = f"{name}_pass2.png"
    nbp_name = f"{name}_nbp.png"
    expr_name = f"{name}_expression_change.png"
    seedvr2_name = f"{name}_seedvr2.png"
    chained_name = f"{name}_expr_then_seedvr2.png"

    nbp_time_str = _read_nbp_time(run_dir)

    # Copy Pass 2 and NBP next to the new outputs so the HTML can use relative paths.
    shutil.copy2(pass2_path, OUTPUT_DIR / pass2_name)
    shutil.copy2(nbp_path, OUTPUT_DIR / nbp_name)

    print("\n  Uploading Pass 2 image...")
    pass2_url = upload_to_fal(pass2_path)

    print("\n  --- Test 1: Expression Change ---")
    expr_data = run_expression_change(pass2_url, tc["expression"])
    expr_img_url = _image_url_or_debug(expr_data["result"], "Expression")
    download_image(expr_img_url, OUTPUT_DIR / expr_name)

    print("\n  --- Test 2: SeedVR2 (quality upscale only) ---")
    seed_data = run_seedvr2(pass2_url)
    download_image(_image_url_or_debug(seed_data["result"], "SeedVR2"), OUTPUT_DIR / seedvr2_name)

    print("\n  --- Test 3: Expression Change → SeedVR2 (chained) ---")
    # Upscale the exact Expression Change image from Test 1 rather than
    # generating a second, different expression edit. The chained column then
    # shows SeedVR2 applied to the image displayed beside it, and one
    # Expression Change API call per case is saved.
    chained_seed = run_seedvr2(expr_img_url)
    download_image(_image_url_or_debug(chained_seed["result"], "SeedVR2 (chained)"), OUTPUT_DIR / chained_name)
    chained_time = expr_data["time"] + chained_seed["time"]

    return {
        "name": name,
        "expression": tc["expression"],
        "pass2_file": pass2_name,
        "nbp_file": nbp_name,
        "nbp_time": nbp_time_str,
        "expr_file": expr_name,
        "expr_time": f"{expr_data['time']:.1f}s",
        "seedvr2_file": seedvr2_name,
        "seedvr2_time": f"{seed_data['time']:.1f}s",
        "chained_file": chained_name,
        "chained_time": f"{chained_time:.1f}s",
    }


def main():
    """Run all TEST_CASES, then write the comparison HTML for the cases that finished.

    A case that raises is reported with its traceback, and the remaining
    cases still run, so completed (already paid-for) results are not lost.
    If any case raised, the process exits with status 1 after the report is
    written.
    """
    import traceback

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    test_results = []
    failed = []

    for tc in TEST_CASES:
        print(f"\n{'='*60}")
        print(f"TEST: {tc['name']} — {tc['expression'][:50]}...")
        print(f"{'='*60}")
        try:
            record = _run_test_case(tc)
        except Exception:  # top-level boundary: report and continue with the next case
            print(f"  ERROR: test case {tc['name']} failed:")
            traceback.print_exc()
            failed.append(tc["name"])
            continue
        if record is not None:
            test_results.append(record)

    if test_results:
        html_path = generate_comparison_html(test_results, OUTPUT_DIR)
        print(f"\n{'='*60}")
        print("ALL TESTS COMPLETE" if not failed else "TESTS COMPLETE (WITH FAILURES)")
        print(f"{'='*60}")
        print(f"  Output dir: {OUTPUT_DIR}")
        print(f"  Comparison: {html_path}")
        print(f"  Open: open \"{html_path}\"")

    if failed:
        print(f"\nFailed test cases: {', '.join(failed)}")
        sys.exit(1)


if __name__ == "__main__":
    # Script entry point: only runs the shootout when executed directly, not on import.
    main()
