#!/usr/bin/env python3
"""
ab_test_models.py — A/B Test T2I Models on Storyboard Shots

Compares generation quality between fal.ai text-to-image endpoints (e.g. Flux 2 vs
Z-Image Turbo) using the same prompts and seed, with each model's own character LoRA
weights. Writes a side-by-side HTML comparison page plus a JSON results file.

Usage:
    python3 ab_test_models.py leviathan/ --episode 1 --character jinx
    python3 ab_test_models.py leviathan/ --episode 1 --character jinx --shots 3,10
    python3 ab_test_models.py leviathan/ --episode 1 --character jinx --dry-run

Env vars:
    FAL_KEY — fal.ai API key (required)
"""

import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path

# Make sibling modules in this script's directory (e.g. train_lora) importable
# no matter which working directory the script is launched from. The directory
# is placed at the front of sys.path, but only once.
_tools_dir = str(Path(__file__).resolve().parent)
if not any(entry == _tools_dir for entry in sys.path):
    sys.path.insert(0, _tools_dir)

from train_lora import load_registry, get_inference_config


# ── Model Configurations ─────────────────────────────────────────────────

MODELS = {
    "flux2": {
        "endpoint": "fal-ai/flux-2/lora",
        "params": {
            "num_inference_steps": 28,
            "guidance_scale": 2.5,
        },
        "label": "Flux 2 Dev",
    },
    "z_image": {
        "endpoint": "fal-ai/z-image/turbo/lora",
        "params": {
            "num_inference_steps": 8,
        },
        "label": "Z-Image Turbo",
    },
}

# Default image sizes
PORTRAIT = {"width": 768, "height": 1344}
SQUARE = {"width": 896, "height": 896}  # ECU: square → center-crop to 9:16
TRIPTYCH = {"width": 2048, "height": 1216}


def generate_t2i(endpoint: str, prompt: str, loras: list, seed: int,
                 image_size: dict, extra_params: dict, dry_run: bool = False) -> dict:
    """Submit one text-to-image request to a fal.ai endpoint and wait for the result.

    Args:
        endpoint: fal.ai application id, e.g. ``"fal-ai/flux-2/lora"``.
        prompt: Full text prompt (any trigger words already prepended).
        loras: List of ``{"path": ..., "scale": ...}`` dicts; may be empty.
        seed: Generation seed.
        image_size: ``{"width": W, "height": H}`` dict.
        extra_params: Model-specific request arguments (e.g. step count). They are
            merged last, so they override the defaults built here. ``None`` means
            no extra arguments.
        dry_run: When True, no request is sent and ``fal_client`` is not imported.

    Returns:
        The endpoint's response dict on success; ``{"error": message}`` if the call
        raised (message is never empty); or a ``{"dry_run": True, ...}`` summary in
        dry-run mode.
    """
    if dry_run:
        # Checked before importing fal_client so dry runs work without the SDK.
        preview = prompt if len(prompt) <= 80 else prompt[:80] + "..."
        return {"dry_run": True, "endpoint": endpoint, "prompt": preview}

    import fal_client

    request = {
        "prompt": prompt,
        "loras": loras,
        "seed": seed,
        "image_size": image_size,
        "enable_safety_checker": False,
        "output_format": "png",
        **(extra_params or {}),
    }

    try:
        return fal_client.subscribe(endpoint, arguments=request)
    except Exception as e:  # report per-request failures instead of aborting the batch
        # Some exceptions stringify to "" (e.g. a bare TimeoutError()); callers test
        # the "error" value for truthiness, so fall back to the repr.
        return {"error": str(e) or repr(e)}


def download_image(url: str, path: Path) -> bool:
    """Download *url* and write the response body to *path*.

    Best-effort: failures are printed and signalled through the return value
    rather than raised, so one bad download does not abort a batch run.

    Args:
        url: Image URL. An empty value is rejected without any network access.
        path: Destination file; overwritten if it exists.

    Returns:
        True if the file was written, False otherwise.
    """
    if not url:
        # Callers fall back to "" when a response carries no image URL.
        print("    Download failed: empty image URL")
        return False

    import requests
    try:
        r = requests.get(url, timeout=60)
        r.raise_for_status()
        path.write_bytes(r.content)
    except (requests.RequestException, OSError) as e:
        # Network/HTTP errors (RequestException) and local write errors (OSError).
        print(f"    Download failed: {e}")
        return False
    return True


def build_comparison_html(results: list, output_dir: Path, episode: int,
                          model_keys: "list | None" = None) -> Path:
    """Write a side-by-side model comparison page for *results* and return its path.

    Args:
        results: Per-shot records: ``shot_id`` (int), ``name``,
            ``generation_approach``, ``seed``, plus one sub-dict per model key
            holding ``filename``/``elapsed``/``seed`` on success, ``error`` on
            failure, or ``dry_run``.
        output_dir: Directory the page is written to. Image ``src`` attributes are
            bare filenames, so the images are expected in this same directory.
        episode: Episode number, used in the page title and the output filename.
        model_keys: Model columns to render, in order. By default every ``MODELS``
            key that appears in at least one result is shown (all of ``MODELS`` if
            none does), so a single-model run does not get an empty column.

    Returns:
        Path of the written ``ab_test_ep_<NNN>.html`` file.
    """
    from html import escape

    if model_keys is None:
        model_keys = []
        if results:
            model_keys = [k for k in MODELS if any(k in r for r in results)] or list(MODELS)

    parts = [f"""<!DOCTYPE html>
<html><head><meta charset="UTF-8"><title>A/B Test — EP{episode:03d}</title>
<style>
  body {{ background: #0e0e12; color: #d0d0d8; font-family: -apple-system, sans-serif; margin: 0; padding: 16px; }}
  h1 {{ color: #00f0ff; font-size: 16px; letter-spacing: 3px; text-transform: uppercase; margin-bottom: 16px; }}
  .shot {{ background: #16161c; border: 1px solid #2e2e3e; border-radius: 8px; padding: 16px; margin-bottom: 16px; }}
  .shot-header {{ display: flex; align-items: center; gap: 12px; margin-bottom: 12px; }}
  .shot-header h2 {{ font-size: 14px; font-family: monospace; color: #d0d0d8; margin: 0; }}
  .shot-header .tag {{ font-size: 10px; padding: 2px 8px; border-radius: 3px; font-family: monospace; }}
  .tag-triptych {{ background: rgba(179,136,255,0.15); color: #b388ff; }}
  .tag-standard {{ background: rgba(33,150,243,0.15); color: #2196f3; }}
  .tag-held {{ background: rgba(96,96,112,0.15); color: #606070; }}
  .compare {{ display: flex; gap: 8px; }}
  .model-col {{ flex: 1; }}
  .model-col h3 {{ font-size: 11px; font-family: monospace; color: #9090a0; margin-bottom: 6px; letter-spacing: 1px; }}
  .model-col img {{ width: 100%; border-radius: 4px; border: 1px solid #2e2e3e; cursor: pointer; }}
  .model-col img:hover {{ border-color: #00f0ff; }}
  .meta {{ font-size: 10px; font-family: monospace; color: #606070; margin-top: 4px; }}
  .error {{ color: #f44336; font-size: 11px; }}
  .lightbox {{ display: none; position: fixed; inset: 0; background: rgba(0,0,0,0.92); z-index: 100; align-items: center; justify-content: center; }}
  .lightbox.open {{ display: flex; }}
  .lightbox img {{ max-width: 95vw; max-height: 95vh; object-fit: contain; }}
  .lightbox .close {{ position: absolute; top: 16px; right: 16px; font-size: 24px; color: #606070; cursor: pointer; background: none; border: none; }}
</style></head><body>
<h1>A/B Model Comparison — EP{episode:03d}</h1>
<p style="font-size:11px;color:#606070;font-family:monospace;margin-bottom:16px;">
  Generated {datetime.now().strftime('%Y-%m-%d %H:%M')} | {len(results)} shots
</p>
"""]

    for r in results:
        # A null approach in the storyboard must not break the page.
        approach = str(r.get("generation_approach") or "unknown")
        if "triptych" in approach:
            tag_class = "tag-triptych"
        elif "flf" in approach:
            tag_class = "tag-standard"
        else:
            tag_class = "tag-held"

        # Storyboard text and exception messages are escaped so characters like
        # "<" or "&" cannot break the markup.
        parts.append(f"""<div class="shot">
  <div class="shot-header">
    <h2>S{r['shot_id']:02d}: {escape(str(r.get('name', '')))}</h2>
    <span class="tag {tag_class}">{escape(approach.replace('_', ' '))}</span>
  </div>
  <div class="compare">
""")
        for model_key in model_keys:
            res = r.get(model_key) or {}
            label = escape(str(MODELS.get(model_key, {}).get("label", model_key)))
            cell = [f'    <div class="model-col"><h3>{label}</h3>']

            if res.get("error"):
                cell.append(f'<div class="error">{escape(str(res["error"]))}</div>')
            elif res.get("filename"):
                src = escape(str(res["filename"]))
                cell.append(f'<img src="{src}" onclick="openLB(this.src)" loading="lazy">')
                elapsed = res.get("elapsed") or 0
                # Prefer the seed recorded for this model's result over the shot's requested seed.
                seed = res.get("seed", r.get("seed", "?"))
                cell.append(f'<div class="meta">{elapsed:.1f}s | seed {escape(str(seed))}</div>')
            elif res.get("dry_run"):
                cell.append('<div class="meta">DRY RUN</div>')
            else:
                cell.append('<div class="meta">No result</div>')

            cell.append('</div>\n')
            parts.append("".join(cell))

        parts.append("  </div>\n</div>\n")

    parts.append("""
<div class="lightbox" id="lb" onclick="if(event.target===this)closeLB()">
  <button class="close" onclick="closeLB()">&times;</button>
  <img id="lb-img">
</div>
<script>
function openLB(src){document.getElementById('lb-img').src=src;document.getElementById('lb').classList.add('open');}
function closeLB(){document.getElementById('lb').classList.remove('open');}
document.addEventListener('keydown',e=>{if(e.key==='Escape')closeLB();});
</script>
</body></html>""")

    html_path = output_dir / f"ab_test_ep_{episode:03d}.html"
    # Explicit UTF-8 to match <meta charset> and the non-ASCII text in the page,
    # independent of the platform's default encoding.
    html_path.write_text("".join(parts), encoding="utf-8")
    return html_path


def _parse_args():
    """Parse and validate the command line.

    Returns:
        ``(args, model_keys, shot_ids)``: the argparse namespace, the validated
        list of ``MODELS`` keys to run (in CLI order), and a set of shot ids to
        keep, or ``None`` to keep every shot.
    """
    parser = argparse.ArgumentParser(description="A/B test T2I models on storyboard shots")
    parser.add_argument("project_dir", help="Project directory (e.g. leviathan/)")
    parser.add_argument("-e", "--episode", type=int, required=True, help="Episode number")
    parser.add_argument("-c", "--character", help="Filter to shots containing this character (e.g. jinx)")
    parser.add_argument("--shots", help="Comma-separated shot IDs to test (default: all matching)")
    parser.add_argument("--models", default="flux2,z_image", help="Comma-separated model keys (default: flux2,z_image)")
    parser.add_argument("--seed", type=int, default=42, help="Generation seed")
    parser.add_argument("--solo-only", action="store_true", help="Skip shots with more than one character (no dual-LoRA)")
    parser.add_argument("--dry-run", action="store_true", help="Print plan without generating")
    parser.add_argument("--z-image-lora", help="Override Z-Image LoRA URL (for freshly trained weights)")
    args = parser.parse_args()

    # Validate model keys before any directory is created or any request is sent.
    model_keys = [m.strip() for m in args.models.split(",") if m.strip()]
    if not model_keys:
        parser.error("--models must name at least one model")
    unknown = [m for m in model_keys if m not in MODELS]
    if unknown:
        parser.error(f"unknown model key(s): {', '.join(unknown)} (available: {', '.join(MODELS)})")

    shot_ids = None
    if args.shots:
        try:
            # Empty items (e.g. "3,,10" or a trailing comma) are ignored.
            shot_ids = {int(s) for s in args.shots.split(",") if s.strip()} or None
        except ValueError:
            parser.error(f"--shots must be comma-separated integers, got {args.shots!r}")

    return args, model_keys, shot_ids


def _resolve_project_dir(raw: str) -> Path:
    """Return *raw* as an absolute directory path; print an error and exit 1 if missing."""
    project_dir = Path(raw).resolve()  # relative paths resolve against the cwd
    if not project_dir.is_dir():
        print(f"ERROR: Project directory not found: {raw}")
        sys.exit(1)
    return project_dir


def _load_storyboard(project_dir: Path, ep_str: str) -> dict:
    """Load ``storyboards/storyboard_ep_<ep_str>.json``; print an error and exit 1 if absent or invalid."""
    sb_path = project_dir / "storyboards" / f"storyboard_ep_{ep_str}.json"
    if not sb_path.is_file():
        print(f"ERROR: Storyboard not found: {sb_path}")
        sys.exit(1)
    try:
        with open(sb_path, encoding="utf-8") as f:
            return json.load(f)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError
        print(f"ERROR: Invalid JSON in {sb_path}: {e}")
        sys.exit(1)


def _load_lora_registry(project_dir: Path) -> dict:
    """Return per-character inference configs keyed by lower-cased character name.

    Shot character lists are lower-cased before lookup, so the keys must be too.
    """
    raw_registry = load_registry(project_dir)
    return {name.lower(): get_inference_config(raw_registry, name) for name in raw_registry}


def _select_shots(storyboard: dict, shot_ids, character, solo_only: bool) -> list:
    """Return the storyboard shots that pass the id, character and solo filters, in order."""
    wanted = character.lower() if character else None
    selected = []
    for shot in storyboard.get("shots", []):
        if shot_ids is not None and shot.get("id") not in shot_ids:
            continue
        chars = [c.lower() for c in shot.get("characters_in_shot", [])]
        if wanted and wanted not in chars:
            continue
        if solo_only and len(chars) > 1:
            continue
        selected.append(shot)
    return selected


def _build_loras(model_key: str, chars: list, lora_registry: dict,
                 z_image_override=None, override_char=None) -> list:
    """Return the fal.ai ``loras`` list for the characters in a shot.

    Each character uses its ``scale_solo`` (default 0.9), or ``scale_dual``
    (default 0.5) when the shot has more than one character. For ``z_image``,
    *z_image_override* replaces the weights of *override_char* (even if that
    character has no registry entry yet); otherwise ``z_image_t2i_path`` is
    preferred, falling back to the generic ``t2i_path``.
    """
    loras = []
    for char in chars:
        overridden = bool(model_key == "z_image" and z_image_override and char == override_char)
        if char not in lora_registry and not overridden:
            continue
        cfg = lora_registry.get(char, {})
        if len(chars) > 1:
            scale = cfg.get("scale_dual", 0.5)
        else:
            scale = cfg.get("scale_solo", 0.9)

        if overridden:
            lora_path = z_image_override
        elif model_key == "z_image":
            lora_path = cfg.get("z_image_t2i_path") or cfg.get("t2i_path")
        else:
            lora_path = cfg.get("t2i_path")

        if lora_path:
            loras.append({"path": lora_path, "scale": scale})
    return loras


def _build_prompt(prompt_field: str, chars: list, lora_registry: dict) -> str:
    """Prefix *prompt_field* with the comma-joined trigger words of the shot's characters."""
    triggers = [lora_registry[c].get("trigger", "") or "" for c in chars if c in lora_registry]
    prefix = ", ".join(t for t in triggers if t)
    return f"{prefix}, {prompt_field}" if prefix else prompt_field


def _center_crop_to_portrait(filepath: Path) -> None:
    """Center-crop the image at *filepath* in place to a 9:16 (width:height) aspect.

    Best-effort: if Pillow is unavailable or the image cannot be processed, a
    note is printed and the uncropped file is kept.
    """
    try:
        from PIL import Image
        with Image.open(filepath) as img:
            w, h = img.size
            # min() guards images that are already narrower than 9:16.
            target_w = min(w, int(h * 9 / 16))
            left = (w - target_w) // 2
            # crop() loads the pixels, so the result outlives the source file handle.
            cropped = img.crop((left, 0, left + target_w, h))
        cropped.save(filepath)
    except Exception as e:  # keep the uncropped image rather than abort the run
        print(f" (crop skipped: {e})", end="")


def _run_model(model_key: str, sid: int, full_prompt: str, loras: list, seed: int,
               img_size: dict, output_dir: Path, *, is_triptych: bool,
               crop_to_portrait: bool) -> dict:
    """Generate one image with *model_key*, download it, and return its result record.

    The record holds ``filename``/``elapsed``/``seed`` on success, or
    ``error``/``elapsed`` on failure.
    """
    model = MODELS[model_key]
    label = model["label"]
    if loras:
        for lora in loras:
            print(f"    {label} LoRA: scale={lora['scale']}, path=...{lora['path'][-40:]}")
    else:
        print(f"    {label}: NO LoRA")
    print(f"    {label}: generating...", end="", flush=True)

    t0 = time.perf_counter()
    gen_result = generate_t2i(
        model["endpoint"], full_prompt, loras, seed,
        img_size, model["params"], dry_run=False,
    )
    elapsed = time.perf_counter() - t0

    if gen_result.get("error"):
        print(f" ERROR: {gen_result['error']}")
        return {"error": gen_result["error"], "elapsed": elapsed}

    images = gen_result.get("images")
    if not images:
        print(f" no images returned ({elapsed:.1f}s)")
        return {"error": "no images in response", "elapsed": elapsed}

    suffix = "strip" if is_triptych else "f1"
    filename = f"S{sid:02d}_{model_key}_{suffix}.png"
    filepath = output_dir / filename
    if not download_image(images[0].get("url", ""), filepath):
        print(f" download failed ({elapsed:.1f}s)")
        return {"error": "download failed", "elapsed": elapsed}

    if crop_to_portrait:
        _center_crop_to_portrait(filepath)
    print(f" OK ({elapsed:.1f}s)")
    return {
        "filename": filename,
        "elapsed": elapsed,
        "seed": gen_result.get("seed", seed),
    }


def main():
    """CLI entry point: select storyboard shots, run each through every model, write reports.

    Exits with status 1 when the project directory or storyboard is missing or the
    storyboard is not valid JSON. Exits with status 0 after printing the plan for
    ``--dry-run``, or when no shot matches the filters. Images, the HTML page and
    the JSON results are written to ``<project>/storyboards/assets/ep_<NNN>/ab_test/``.
    """
    args, model_keys, shot_ids = _parse_args()
    project_dir = _resolve_project_dir(args.project_dir)

    ep_str = str(args.episode).zfill(3)
    storyboard = _load_storyboard(project_dir, ep_str)
    lora_registry = _load_lora_registry(project_dir)

    shots_to_test = _select_shots(storyboard, shot_ids, args.character, args.solo_only)
    if not shots_to_test:
        print("No shots match the filter criteria.")
        sys.exit(0)

    output_dir = project_dir / "storyboards" / "assets" / f"ep_{ep_str}" / "ab_test"

    # Header
    print(f"{'=' * 60}")
    print(f"  A/B TEST — EP{args.episode:03d}")
    print(f"{'=' * 60}")
    print(f"  Character filter: {args.character or 'ALL'}")
    print(f"  Solo only: {args.solo_only}")
    print(f"  Shots: {len(shots_to_test)}")
    print(f"  Models: {', '.join(model_keys)}")
    print(f"  Seed: {args.seed}")
    print(f"  Output: {output_dir}")
    if args.z_image_lora:
        print(f"  Z-Image LoRA override: {args.z_image_lora}")
    print(f"{'=' * 60}")
    print()

    if args.dry_run:
        for shot in shots_to_test:
            chars = shot.get("characters_in_shot", [])
            approach = shot.get("generation_approach", "unknown")
            name = str(shot.get("name", ""))
            print(f"  S{shot['id']:02d}: {name:<30s} [{approach}] chars={chars}")
        print(f"\n  Total: {len(shots_to_test)} shots x {len(model_keys)} models = {len(shots_to_test) * len(model_keys)} generations")
        sys.exit(0)

    # Only create the output directory once we are really generating.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Character names in shots are lower-cased, so the override target must be too.
    override_char = args.character.lower() if args.character else None
    all_results = []

    for i, shot in enumerate(shots_to_test, start=1):
        sid = shot["id"]
        name = shot.get("name", "")
        chars = [c.lower() for c in shot.get("characters_in_shot", [])]
        approach = shot.get("generation_approach") or "unknown"
        is_triptych = "triptych" in approach
        is_ecu = (shot.get("shot_type") or "").upper() == "ECU"

        print(f"[{i}/{len(shots_to_test)}] S{sid:02d}: {name}")

        # Triptychs take priority: they use their own prompt and the wide strip size.
        if is_triptych:
            prompt_field = shot.get("triptych_prompt", shot.get("first_frame", ""))
            img_size = TRIPTYCH
        else:
            prompt_field = shot.get("first_frame", "")
            img_size = SQUARE if is_ecu else PORTRAIT

        # ECU detail shots carry no character identity: no trigger words, no LoRAs.
        full_prompt = prompt_field if is_ecu else _build_prompt(prompt_field, chars, lora_registry)

        result_entry = {
            "shot_id": sid,
            "name": name,
            "generation_approach": approach,
            "characters": chars,
            "seed": args.seed,
        }

        for model_key in model_keys:
            loras = [] if is_ecu else _build_loras(
                model_key, chars, lora_registry, args.z_image_lora, override_char,
            )
            result_entry[model_key] = _run_model(
                model_key, sid, full_prompt, loras, args.seed, img_size, output_dir,
                is_triptych=is_triptych,
                # Only square ECU renders are cropped; a triptych strip keeps its width.
                crop_to_portrait=img_size is SQUARE,
            )

        all_results.append(result_entry)

    # Write comparison HTML
    html_path = build_comparison_html(all_results, output_dir, args.episode)

    # Write results JSON
    results_json_path = output_dir / f"ab_test_ep_{args.episode:03d}.json"
    with open(results_json_path, "w", encoding="utf-8") as f:
        json.dump({
            "episode": args.episode,
            "character": args.character,
            "models": model_keys,
            "seed": args.seed,
            "generated_at": datetime.now().isoformat(),
            "results": all_results,
        }, f, indent=2)

    # Summary
    print()
    print(f"{'=' * 60}")
    print("  RESULTS")
    print(f"{'=' * 60}")
    for mk in model_keys:
        successes = sum(1 for r in all_results if r.get(mk, {}).get("filename"))
        errors = sum(1 for r in all_results if r.get(mk, {}).get("error"))
        print(f"  {MODELS[mk]['label']}: {successes} OK, {errors} failed")
    print(f"  HTML:    {html_path}")
    print(f"  JSON:    {results_json_path}")
    print(f"{'=' * 60}")


if __name__ == "__main__":
    main()
