#!/usr/bin/env python3
"""
ab_test_wardrobe.py — A/B Test: T2I+LoRA vs Flux 2 Edit (Reference-Conditioned) vs img2img+LoRA

Tests whether Flux 2 Edit with reference images produces better wardrobe/prop
consistency than pure T2I with LoRA, using the same shots from a storyboard.

Three generation methods compared:
  A) Z-Image Turbo T2I + LoRA     — current pipeline (pure text-to-image)
  B) Flux 2 Edit + reference       — multi-reference conditioning (wardrobe hero + location ref)
  C) Z-Image img2img + LoRA        — wardrobe hero as img2img base at low strength

The test generates a "wardrobe hero frame" first, then uses it as reference input
for methods B and C.

Usage:
    python3 ab_test_wardrobe.py leviathan/ --episode 1 --character jinx
    python3 ab_test_wardrobe.py leviathan/ --episode 1 --character jinx --shots 3,5,9,15
    python3 ab_test_wardrobe.py leviathan/ --episode 1 --character jinx --dry-run

Env vars:
    FAL_KEY — fal.ai API key (required)
"""

import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

# Add tools to path
_tools_dir = str(Path(__file__).resolve().parent)
if _tools_dir not in sys.path:
    sys.path.insert(0, _tools_dir)

from train_lora import load_registry, get_inference_config
from prompt_engine import PromptEngine


# ── Method Configurations ────────────────────────────────────────────────

# Registry of the three generation methods under comparison. The keys double
# as the per-shot result keys written by main() and as the column order used
# by build_comparison_html(). Per-method fields:
#   label       — long name, printed in the console summary
#   short       — short name, used for column / legend / summary-card headings
#   description — explanatory text shown in the report (the legend truncates
#                 it to 80 characters)
#   color       — CSS colour used for that method's headings and stats
# NOTE: the img2img description mentions "strength 0.40", which is only the
# default of --img2img-strength; the text does not follow the CLI value.
METHODS = {
    "z_image_t2i": {
        "label": "Z-Image T2I + LoRA",
        "short": "Z-Image T2I",
        "description": "Pure text-to-image with character LoRA. Current pipeline approach.",
        "color": "#00e5ff",
    },
    "flux2_edit": {
        "label": "Flux 2 Edit + Refs",
        "short": "Flux 2 Edit",
        "description": "Multi-reference conditioning with wardrobe hero frame. No LoRA — identity from reference image.",
        "color": "#b388ff",
    },
    "z_image_img2img": {
        "label": "Z-Image img2img + LoRA",
        "short": "Z-Image img2img",
        "description": "Wardrobe hero as img2img base (strength 0.40) + character LoRA + prompt.",
        "color": "#69f0ae",
    },
}

# Image size passed as ``image_size`` to every generation call in this script
# (hero frame and all three methods), so outputs share one portrait format.
PORTRAIT = {"width": 768, "height": 1344}

# Prompt template for the wardrobe hero frame. main() fills in:
#   {trigger}       — the character's LoRA trigger from the inference config
#   {visual}        — the character's "visual" text from the storyboard
#   {wardrobe_desc} — the same "visual" text when present, otherwise the
#                     character's "hair_makeup" field
WARDROBE_HERO_PROMPT = (
    "{trigger}, {visual}. "
    "Full body three-quarter view, standing upright, clean studio-quality lighting "
    "with dramatic side light, dark industrial background. "
    "Showing complete outfit clearly: {wardrobe_desc}. "
    "Photorealistic, ultra-detailed material textures, "
    "correct human anatomy, natural hands and fingers, sharp focus"
)


# ── Generation Functions ────────────────────────────────────────────────

def generate_z_image_t2i(prompt: str, loras: list, seed: int,
                          image_size: dict) -> dict:
    """Method A: run a Z-Image Turbo text-to-image request with LoRA weights.

    Args:
        prompt: Full text prompt for the image.
        loras: LoRA entries forwarded unchanged as the ``loras`` argument.
        seed: Seed forwarded to the endpoint.
        image_size: Size mapping forwarded as ``image_size``.

    Returns:
        Whatever ``fal_client.subscribe`` returns for the
        ``fal-ai/z-image/turbo/lora`` endpoint. Callers in this file read an
        ``"images"`` list and an optional ``"seed"`` from it.
    """
    # Imported lazily so the module (and --dry-run) works without fal_client.
    import fal_client

    payload = dict(
        prompt=prompt,
        loras=loras,
        seed=seed,
        image_size=image_size,
        num_inference_steps=8,
        enable_safety_checker=False,
        output_format="png",
    )
    return fal_client.subscribe("fal-ai/z-image/turbo/lora", arguments=payload)


def generate_flux2_edit(prompt: str, image_urls: list, seed: int,
                         image_size: dict) -> dict:
    """Method B: run a Flux 2 Edit request conditioned on reference images.

    Args:
        prompt: Edit prompt; may refer to the reference images by position.
        image_urls: Reference image URLs forwarded as ``image_urls``.
        seed: Seed forwarded to the endpoint.
        image_size: Size mapping forwarded as ``image_size``.

    Returns:
        Whatever ``fal_client.subscribe`` returns for the
        ``fal-ai/flux-2/edit`` endpoint.
    """
    # Imported lazily so the module (and --dry-run) works without fal_client.
    import fal_client

    payload = dict(
        prompt=prompt,
        image_urls=image_urls,
        seed=seed,
        image_size=image_size,
        num_inference_steps=28,
        guidance_scale=3.5,
        output_format="png",
    )
    return fal_client.subscribe("fal-ai/flux-2/edit", arguments=payload)


def generate_z_image_img2img(prompt: str, loras: list, seed: int,
                              image_url: str, strength: float,
                              image_size: dict) -> dict:
    """Method C: run a Z-Image Turbo image-to-image request with LoRA weights.

    Args:
        prompt: Full text prompt for the image.
        loras: LoRA entries forwarded unchanged as the ``loras`` argument.
        seed: Seed forwarded to the endpoint.
        image_url: URL of the base image (the wardrobe hero frame in main()).
        strength: img2img strength forwarded as ``strength``.
        image_size: Size mapping forwarded as ``image_size``.

    Returns:
        Whatever ``fal_client.subscribe`` returns for the
        ``fal-ai/z-image/turbo/image-to-image/lora`` endpoint.
    """
    # Imported lazily so the module (and --dry-run) works without fal_client.
    import fal_client

    endpoint = "fal-ai/z-image/turbo/image-to-image/lora"
    payload = dict(
        prompt=prompt,
        loras=loras,
        seed=seed,
        image_url=image_url,
        strength=strength,
        image_size=image_size,
        num_inference_steps=8,
        enable_safety_checker=False,
        output_format="png",
    )
    return fal_client.subscribe(endpoint, arguments=payload)


def download_image(url: str, path: Path) -> bool:
    """Fetch ``url`` and write the response body to ``path``.

    Best-effort: any failure (request error, non-2xx status, write error) is
    reported on stdout and turned into a ``False`` return instead of raising.

    Returns:
        True when the file was written, False otherwise.
    """
    import requests

    try:
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        path.write_bytes(response.content)
    except Exception as exc:
        print(f"    Download failed: {exc}")
        return False
    return True


def get_image_url(result: dict) -> Optional[str]:
    """Return the ``url`` of the first entry in ``result["images"]``.

    Yields None when the ``images`` key is missing, empty or falsy, or when
    the first entry carries no ``url``.
    """
    first_batch = result.get("images", [])
    return first_batch[0].get("url") if first_batch else None


# ── Prompt Building ─────────────────────────────────────────────────────

def build_shot_prompt(engine: PromptEngine, shot: dict,
                      frame_type: str = "hero") -> str:
    """Compile the prompt for ``shot`` by delegating to ``engine.compile``.

    Args:
        engine: Prompt engine whose ``compile(shot, frame_type)`` is called.
        shot: Storyboard shot entry, passed through unchanged.
        frame_type: Frame variant passed through unchanged (default "hero").

    Returns:
        The value returned by ``engine.compile``.
    """
    compiled = engine.compile(shot, frame_type)
    return compiled


def build_flux2_edit_prompt(shot: dict, char_visual: str,
                            frame_type: str = "hero") -> str:
    """Build a Flux 2 Edit prompt that refers to the reference images.

    Image 1 is the wardrobe hero frame (the character in the target outfit).
    The prompt is a sequence of sentence-like fragments joined with ". ":
    reference to image 1, the action, a shot-type phrase, atmosphere,
    lighting, expression, and a fixed quality suffix. Empty fragments are
    dropped.

    Args:
        shot: Storyboard shot entry. Keys read: ``hero_action``,
            ``anticipation_action``, ``aftermath_action``, ``action``,
            ``subject``, ``shot_type``, ``emotion``, ``atmosphere``,
            ``lighting``. All are optional; null values are treated as absent.
        char_visual: Accepted for interface compatibility; not used when
            building the prompt (identity comes from the reference image).
        frame_type: "hero", "first" or "last" selects the frame-specific
            action key; any other value goes straight to the generic action.

    Returns:
        The assembled prompt string.
    """
    # Frame-specific action keys; a missing or empty value falls through to
    # the generic action below.
    frame_action_keys = {
        "hero": "hero_action",
        "first": "anticipation_action",
        "last": "aftermath_action",
    }
    frame_key = frame_action_keys.get(frame_type)
    frame_action = shot.get(frame_key) if frame_key else None

    # Use `or` chaining rather than nested .get() defaults so that an
    # "action" key that is present but empty/null still falls back to
    # "subject".
    action = frame_action or shot.get("action") or shot.get("subject") or ""

    shot_type_phrases = {
        "ECU": "extreme close-up",
        "CU": "close-up portrait",
        "MCU": "medium close-up",
        "MS": "medium shot",
        "LS": "long shot, full body visible",
        "WIDE": "wide establishing shot",
        "POV": "point of view shot",
    }
    # A missing shot_type defaults to a medium shot; unknown codes add nothing.
    shot_type = shot.get("shot_type", "MS")
    emotion = shot.get("emotion") or ""

    fragments = [
        "The woman from image 1 wearing her exact same outfit and gear",
        action,
        shot_type_phrases.get(shot_type, ""),
        shot.get("atmosphere") or "",
        shot.get("lighting") or "",
        f"expression: {emotion.lower()}" if emotion else "",
        "Photorealistic, cinematic, Kodak Vision3 500T, visible grain, "
        "correct human anatomy, natural hands, sharp focus",
    ]

    # Trim surrounding whitespace and trailing periods so the ". " joiner
    # never yields ".." or ". ." between fragments.
    cleaned = (frag.strip().rstrip(". ") for frag in fragments if frag)
    return ". ".join(frag for frag in cleaned if frag)


# ── HTML Report ─────────────────────────────────────────────────────────

def build_comparison_html(results: list, hero_info: dict,
                           output_dir: Path, episode: int,
                           character: str) -> Path:
    """Write the 3-column comparison report and return its path.

    The page has a header, the wardrobe hero frame section, a method legend,
    one row per shot with a column per method, an aggregate summary, and a
    small click-to-zoom lightbox.

    All dynamic text (character name, shot names, prompts, error messages,
    image paths) is HTML-escaped before being interpolated, because error
    strings come from exceptions and may contain markup characters.

    Args:
        results: Per-shot result dicts as built by main(). Each has
            ``shot_id``, ``name``, ``shot_type``, ``generation_approach``,
            ``prompt_preview`` and one sub-dict per method key holding either
            ``filename``/``elapsed`` (and optionally ``cost``) or ``error``.
        hero_info: Hero frame details; keys read are ``filename``, ``url``,
            ``elapsed``, ``prompt`` and ``seed``.
        output_dir: Directory the report is written into. Image ``src``
            values are filenames relative to this directory.
        episode: Episode number, rendered zero-padded to three digits.
        character: Character key; used title-cased in the page and verbatim
            in the output filename.

    Returns:
        Path of the written HTML file.
    """
    # The stdlib module is called ``html``; import only the helper so no local
    # name can shadow it.
    from html import escape

    method_keys = ["z_image_t2i", "flux2_edit", "z_image_img2img"]
    character_title = escape(character.title())
    generated_at = datetime.now().strftime('%Y-%m-%d %H:%M')
    seed_display = escape(str(hero_info.get('seed', 42)))

    # Collect fragments and join once at the end instead of repeated `+=`.
    parts = [f"""<!DOCTYPE html>
<html><head><meta charset="UTF-8">
<title>Wardrobe Consistency Test — EP{episode:03d} — {character_title}</title>
<style>
  * {{ box-sizing: border-box; margin: 0; padding: 0; }}
  body {{ background: #0a0a0f; color: #c8c8d4; font-family: 'SF Mono', 'Fira Code', 'Cascadia Code', monospace; }}

  /* Header */
  .header {{
    padding: 24px 32px 20px;
    border-bottom: 1px solid #1a1a28;
    background: linear-gradient(180deg, #0e0e16 0%, #0a0a0f 100%);
  }}
  .header h1 {{
    font-size: 13px;
    letter-spacing: 4px;
    text-transform: uppercase;
    color: #00e5ff;
    margin-bottom: 8px;
  }}
  .header .subtitle {{
    font-size: 11px;
    color: #505068;
    line-height: 1.6;
  }}
  .header .subtitle span {{ color: #707088; }}

  /* Hero frame section */
  .hero-section {{
    padding: 24px 32px;
    border-bottom: 1px solid #1a1a28;
    background: #0d0d14;
  }}
  .hero-section h2 {{
    font-size: 11px;
    letter-spacing: 3px;
    text-transform: uppercase;
    color: #b388ff;
    margin-bottom: 12px;
  }}
  .hero-content {{
    display: flex;
    gap: 20px;
    align-items: flex-start;
  }}
  .hero-content img {{
    height: 320px;
    width: auto;
    border-radius: 6px;
    border: 1px solid #2a2a3e;
    cursor: pointer;
  }}
  .hero-content img:hover {{ border-color: #b388ff; }}
  .hero-meta {{
    font-size: 10px;
    color: #505068;
    line-height: 1.8;
  }}
  .hero-meta .label {{ color: #707088; }}

  /* Method legend */
  .legend {{
    display: flex;
    gap: 24px;
    padding: 16px 32px;
    border-bottom: 1px solid #1a1a28;
    background: #0c0c12;
  }}
  .legend-item {{
    display: flex;
    align-items: center;
    gap: 8px;
  }}
  .legend-dot {{
    width: 8px;
    height: 8px;
    border-radius: 50%;
  }}
  .legend-label {{ font-size: 10px; color: #707088; }}
  .legend-desc {{ font-size: 9px; color: #404058; }}

  /* Shot rows */
  .shots {{ padding: 16px 32px; }}

  .shot {{
    background: #0e0e16;
    border: 1px solid #1a1a28;
    border-radius: 8px;
    margin-bottom: 16px;
    overflow: hidden;
  }}
  .shot-header {{
    display: flex;
    align-items: center;
    gap: 12px;
    padding: 12px 16px;
    border-bottom: 1px solid #1a1a28;
    background: #101018;
  }}
  .shot-header h3 {{
    font-size: 12px;
    color: #d0d0d8;
  }}
  .shot-header .tag {{
    font-size: 9px;
    padding: 2px 8px;
    border-radius: 3px;
    text-transform: uppercase;
    letter-spacing: 1px;
  }}
  .tag-shot-type {{ background: rgba(0,229,255,0.08); color: #00e5ff; }}
  .tag-approach {{ background: rgba(179,136,255,0.08); color: #b388ff; }}

  .shot-prompt {{
    padding: 8px 16px;
    font-size: 9px;
    color: #404058;
    border-bottom: 1px solid #14141e;
    line-height: 1.5;
    max-height: 40px;
    overflow: hidden;
  }}
  .shot-prompt:hover {{
    max-height: none;
    color: #606078;
  }}

  .compare {{
    display: grid;
    grid-template-columns: 1fr 1fr 1fr;
    gap: 1px;
    background: #1a1a28;
  }}

  .method-col {{
    background: #0e0e16;
    padding: 12px;
  }}
  .method-col h4 {{
    font-size: 9px;
    letter-spacing: 2px;
    text-transform: uppercase;
    margin-bottom: 8px;
  }}
  .method-col img {{
    width: 100%;
    border-radius: 4px;
    border: 1px solid #1a1a28;
    cursor: pointer;
    transition: border-color 0.15s;
  }}
  .method-col img:hover {{ border-color: #00e5ff; }}
  .method-meta {{
    display: flex;
    justify-content: space-between;
    margin-top: 6px;
    font-size: 9px;
    color: #404058;
  }}
  .method-meta .cost {{ color: #69f0ae; }}
  .error {{ color: #ef5350; font-size: 10px; padding: 20px; text-align: center; }}

  /* Summary */
  .summary {{
    padding: 24px 32px;
    border-top: 1px solid #1a1a28;
    background: #0c0c12;
  }}
  .summary h2 {{
    font-size: 11px;
    letter-spacing: 3px;
    text-transform: uppercase;
    color: #69f0ae;
    margin-bottom: 12px;
  }}
  .summary-grid {{
    display: grid;
    grid-template-columns: 1fr 1fr 1fr;
    gap: 16px;
  }}
  .summary-card {{
    background: #0e0e16;
    border: 1px solid #1a1a28;
    border-radius: 6px;
    padding: 16px;
  }}
  .summary-card h3 {{
    font-size: 10px;
    letter-spacing: 2px;
    text-transform: uppercase;
    margin-bottom: 8px;
  }}
  .summary-card .stat {{
    font-size: 20px;
    font-weight: 600;
    margin-bottom: 4px;
  }}
  .summary-card .stat-label {{
    font-size: 9px;
    color: #505068;
  }}

  /* Lightbox */
  .lightbox {{
    display: none;
    position: fixed;
    inset: 0;
    background: rgba(0,0,0,0.95);
    z-index: 100;
    align-items: center;
    justify-content: center;
  }}
  .lightbox.open {{ display: flex; }}
  .lightbox img {{
    max-width: 95vw;
    max-height: 95vh;
    object-fit: contain;
    border-radius: 4px;
  }}
  .lightbox .close {{
    position: absolute;
    top: 16px;
    right: 20px;
    font-size: 28px;
    color: #505068;
    cursor: pointer;
    background: none;
    border: none;
    font-family: monospace;
  }}
  .lightbox .close:hover {{ color: #c8c8d4; }}
  .lightbox .nav {{
    position: absolute;
    font-size: 11px;
    color: #505068;
    bottom: 20px;
    left: 50%;
    transform: translateX(-50%);
  }}
</style>
</head>
<body>

<div class="header">
  <h1>Wardrobe Consistency Test — EP{episode:03d}</h1>
  <div class="subtitle">
    <span>Character:</span> {character_title} &nbsp;|&nbsp;
    <span>Generated:</span> {generated_at} &nbsp;|&nbsp;
    <span>Shots:</span> {len(results)} &nbsp;|&nbsp;
    <span>Seed:</span> {seed_display}
  </div>
  <div class="subtitle" style="margin-top:6px; color:#404058;">
    Testing whether reference-conditioned generation (Flux 2 Edit) produces more consistent
    wardrobe/props than pure text-to-image with LoRA.
  </div>
</div>

<div class="hero-section">
  <h2>Wardrobe Hero Frame (Reference Anchor)</h2>
  <div class="hero-content">
"""]

    # Hero frame image. When the hero was supplied by URL there is no local
    # file, so fall back to the URL rather than reporting a failure.
    hero_src = hero_info.get("filename") or hero_info.get("url") or ""
    hero_elapsed = hero_info.get("elapsed", 0)
    hero_prompt = hero_info.get("prompt") or ""
    hero_prompt_display = escape(hero_prompt[:200])
    # main() stores the prompt already cut to 200 characters, so a prompt of
    # that length (or longer) is treated as truncated.
    if len(hero_prompt) >= 200:
        hero_prompt_display += "..."

    if hero_src:
        parts.append(f'    <img src="{escape(hero_src)}" onclick="openLB(this.src)">\n')
    else:
        parts.append('    <div class="error">Hero frame generation failed</div>\n')

    parts.append(f"""    <div class="hero-meta">
      <div><span class="label">Method:</span> Z-Image T2I + {character_title} LoRA (clean reference shot)</div>
      <div><span class="label">Time:</span> {hero_elapsed:.1f}s</div>
      <div><span class="label">Purpose:</span> This image is passed as reference input to Flux 2 Edit
        and as img2img base to Z-Image img2img for every shot below.</div>
      <div style="margin-top:8px;"><span class="label">Prompt:</span> {hero_prompt_display}</div>
    </div>
  </div>
</div>

<div class="legend">
""")

    for mk in method_keys:
        m = METHODS[mk]
        parts.append(f"""  <div class="legend-item">
    <div class="legend-dot" style="background:{m['color']};"></div>
    <span class="legend-label">{escape(m['short'])}</span>
    <span class="legend-desc">{escape(m['description'][:80])}</span>
  </div>
""")

    parts.append("</div>\n\n<div class=\"shots\">\n")

    # Shot rows
    for r in results:
        approach = str(r.get("generation_approach") or "unknown")
        shot_type = str(r.get("shot_type") or "MS")
        shot_name = escape(str(r.get("name") or ""))
        # Truncate before escaping so an entity is never cut in half.
        prompt_preview = escape(str(r.get("prompt_preview") or "")[:200])

        parts.append(f"""<div class="shot">
  <div class="shot-header">
    <h3>S{r['shot_id']:02d}: {shot_name}</h3>
    <span class="tag tag-shot-type">{escape(shot_type)}</span>
    <span class="tag tag-approach">{escape(approach.replace('_', ' '))}</span>
  </div>
  <div class="shot-prompt" title="Hover to expand">{prompt_preview}</div>
  <div class="compare">
""")

        for mk in method_keys:
            m = METHODS[mk]
            mr = r.get(mk) or {}

            parts.append('    <div class="method-col">\n')
            parts.append(f'      <h4 style="color:{m["color"]}">{escape(m["short"])}</h4>\n')

            if mr.get("error"):
                parts.append(f'      <div class="error">{escape(str(mr["error"]))}</div>\n')
            elif mr.get("filename"):
                elapsed = mr.get("elapsed", 0)
                parts.append(
                    f'      <img src="{escape(str(mr["filename"]))}" '
                    'onclick="openLB(this.src)" loading="lazy">\n'
                )
                parts.append(f'      <div class="method-meta"><span>{elapsed:.1f}s</span>')
                if mr.get("cost"):
                    parts.append(f'<span class="cost">${mr["cost"]:.3f}</span>')
                parts.append('</div>\n')
            else:
                parts.append('      <div class="error">No result</div>\n')

            parts.append('    </div>\n')

        parts.append("  </div>\n</div>\n\n")

    parts.append("</div>\n")

    # Summary section
    parts.append('<div class="summary">\n  <h2>Aggregate Results</h2>\n  <div class="summary-grid">\n')

    for mk in method_keys:
        m = METHODS[mk]
        # Only runs that produced a file count towards success and timing.
        ok_runs = [r[mk] for r in results if (r.get(mk) or {}).get("filename")]
        successes = len(ok_runs)
        total_time = sum(run.get("elapsed", 0) for run in ok_runs)
        avg_time = total_time / max(successes, 1)

        parts.append(f"""    <div class="summary-card">
      <h3 style="color:{m['color']}">{escape(m['short'])}</h3>
      <div class="stat" style="color:{m['color']}">{successes}/{len(results)}</div>
      <div class="stat-label">shots generated</div>
      <div class="stat" style="color:{m['color']}; font-size:16px; margin-top:8px;">{avg_time:.1f}s</div>
      <div class="stat-label">avg generation time</div>
      <div style="font-size:9px; color:#404058; margin-top:8px;">{escape(m['description'])}</div>
    </div>
""")

    parts.append("""  </div>
</div>

<div class="lightbox" id="lb" onclick="if(event.target===this)closeLB()">
  <button class="close" onclick="closeLB()">&times;</button>
  <img id="lb-img">
  <div class="nav">Click outside or press Escape to close</div>
</div>

<script>
function openLB(src) {
  document.getElementById('lb-img').src = src;
  document.getElementById('lb').classList.add('open');
}
function closeLB() {
  document.getElementById('lb').classList.remove('open');
}
document.addEventListener('keydown', e => {
  if (e.key === 'Escape') closeLB();
});
</script>
</body></html>""")

    html_path = output_dir / f"wardrobe_test_ep_{episode:03d}_{character}.html"
    # The page declares charset=UTF-8, so write it as UTF-8 regardless of the
    # platform's locale encoding.
    html_path.write_text("".join(parts), encoding="utf-8")
    return html_path


# ── Main ────────────────────────────────────────────────────────────────

def _run_method(generate, filename: str, output_dir: Path,
                fallback_seed: int) -> dict:
    """Run one generation call, download its image, and return a result record.

    Args:
        generate: Zero-argument callable that performs the generation request
            and returns the service result dict.
        filename: Name the downloaded image is saved under in ``output_dir``.
        output_dir: Directory the image is written into.
        fallback_seed: Seed recorded when the result carries no ``seed``.

    Returns:
        ``{"filename", "elapsed", "seed"}`` on success, otherwise
        ``{"error", "elapsed"}``. Never raises for generation failures; the
        status is also printed to finish the caller's progress line.
    """
    started = time.time()
    try:
        result = generate()
        elapsed = time.time() - started
        url = get_image_url(result)
        if not url:
            print(" no image")
            return {"error": "no image returned", "elapsed": elapsed}
        # download_image prints its own failure message.
        if not download_image(url, output_dir / filename):
            return {"error": "download failed", "elapsed": elapsed}
        print(f" OK ({elapsed:.1f}s)")
        return {
            "filename": filename,
            "elapsed": elapsed,
            "seed": result.get("seed", fallback_seed),
        }
    except Exception as e:
        # One failing method must not abort the other methods or shots.
        elapsed = time.time() - started
        print(f" ERROR: {e}")
        return {"error": str(e), "elapsed": elapsed}


def main():
    """CLI entry point: run the wardrobe consistency A/B test.

    Steps:
      1. Parse arguments, load the episode storyboard and the LoRA registry.
      2. Select solo shots for the character (explicit ``--shots`` list, or an
         automatic pick of at most two per shot type, eight in total).
      3. Generate the wardrobe hero frame (or take it from ``--hero-url``).
      4. For every shot run methods A, B and C and download the images.
      5. Write the HTML comparison report and a JSON results file, then print
         a per-method summary.

    Exits with status 1 on invalid input or when the hero frame cannot be
    produced, and with status 0 after a dry run or when no shots match.
    """
    parser = argparse.ArgumentParser(
        description="A/B test wardrobe consistency: T2I+LoRA vs Flux 2 Edit vs img2img+LoRA"
    )
    parser.add_argument("project_dir", help="Project directory (e.g. leviathan/)")
    parser.add_argument("-e", "--episode", type=int, required=True)
    parser.add_argument("-c", "--character", required=True,
                        help="Character to test (e.g. jinx)")
    parser.add_argument("--shots", help="Comma-separated shot IDs (default: auto-select 6-8)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--img2img-strength", type=float, default=0.40,
                        help="Strength for method C img2img (default: 0.40)")
    parser.add_argument("--hero-url", help="Skip hero generation, use this CDN URL instead")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    # Resolve paths. resolve() already anchors relative paths at the current
    # working directory, so no separate cwd-based fallback is needed.
    project_dir = Path(args.project_dir).resolve()
    if not project_dir.is_dir():
        print(f"ERROR: Project directory not found: {args.project_dir}")
        sys.exit(1)

    # Load storyboard
    ep_str = str(args.episode).zfill(3)
    sb_path = project_dir / "storyboards" / f"storyboard_ep_{ep_str}.json"
    if not sb_path.is_file():
        print(f"ERROR: Storyboard not found: {sb_path}")
        sys.exit(1)

    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(sb_path, encoding="utf-8") as f:
            storyboard = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in {sb_path}: {e}")
        sys.exit(1)

    # Load LoRA registry and resolve the inference config of every character.
    raw_registry = load_registry(project_dir)
    lora_registry = {
        char_name: get_inference_config(raw_registry, char_name)
        for char_name in raw_registry
    }

    char = args.character.lower()
    if char not in lora_registry:
        print(f"ERROR: Character '{char}' not found in LoRA registry")
        sys.exit(1)

    char_config = lora_registry[char]
    char_data = storyboard.get("characters", {}).get(char, {})
    char_visual = char_data.get("visual", "")

    # Prompt engine for structured prompts (it receives its own copy of the
    # registry mapping).
    engine = PromptEngine(project_dir, storyboard, dict(lora_registry), model_key="z_image")

    # Parse the explicit shot list, if any.
    shot_ids = None
    if args.shots:
        try:
            shot_ids = [int(s.strip()) for s in args.shots.split(",") if s.strip()]
        except ValueError:
            shot_ids = []
        if not shot_ids:
            print(f"ERROR: --shots must be comma-separated integer shot IDs, got: {args.shots}")
            sys.exit(1)

    # Filter shots — solo shots for this character
    shots_to_test = []
    for shot in storyboard.get("shots", []):
        sid = shot.get("id")
        if shot_ids and sid not in shot_ids:
            continue

        chars = [c.lower() for c in shot.get("characters_in_shot", [])]
        if char not in chars:
            continue
        # Solo shots only — multi-character needs different handling
        if len(chars) > 1:
            continue

        shots_to_test.append(shot)

    if shot_ids:
        # Tell the user which requested shots were dropped by the filter
        # instead of skipping them silently.
        selected_ids = {shot.get("id") for shot in shots_to_test}
        skipped = [sid for sid in shot_ids if sid not in selected_ids]
        if skipped:
            skipped_str = ", ".join(str(sid) for sid in skipped)
            print(f"WARNING: Skipping shot(s) {skipped_str}: "
                  f"not found or not a solo shot for '{char}'")
    else:
        # Auto-select: pick diverse shot types (two per type), max 8
        type_seen = {}
        filtered = []
        for shot in shots_to_test:
            st = shot.get("shot_type", "MS")
            if type_seen.get(st, 0) < 2:
                filtered.append(shot)
                type_seen[st] = type_seen.get(st, 0) + 1
            if len(filtered) >= 8:
                break
        shots_to_test = filtered

    if not shots_to_test:
        print("No matching shots found.")
        sys.exit(0)

    # Output directory
    output_dir = project_dir / "storyboards" / "assets" / f"ep_{ep_str}" / "wardrobe_test"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Build LoRA list for Z-Image methods
    z_lora_path = char_config.get("z_image_t2i_path") or char_config.get("t2i_path")
    z_loras = []
    if z_lora_path:
        z_loras = [{"path": z_lora_path, "scale": char_config.get("scale_solo", 0.9)}]

    trigger = char_config.get("trigger", "")

    # Header
    rule = "=" * 65
    print(rule)
    print(f"  WARDROBE CONSISTENCY TEST — EP{args.episode:03d} — {char.upper()}")
    print(rule)
    print(f"  Shots: {len(shots_to_test)}")
    print("  Methods: Z-Image T2I, Flux 2 Edit, Z-Image img2img")
    print(f"  Seed: {args.seed}")
    print(f"  LoRA: {z_lora_path or 'NONE'}".rstrip())
    print(f"  img2img strength: {args.img2img_strength}")
    print(f"  Output: {output_dir}")
    print(rule)

    if args.dry_run:
        print("\n  WARDROBE HERO FRAME: Would generate clean reference shot")
        print(f"  Trigger: {trigger}")
        print()
        for shot in shots_to_test:
            st = shot.get("shot_type", "MS")
            approach = shot.get("generation_approach", "?")
            # .get() mirrors the real run, which tolerates shots without a name.
            print(f"  S{shot['id']:02d}: {shot.get('name', ''):<30s} [{st}] [{approach}]")
        total = len(shots_to_test) * 3 + 1  # +1 for hero
        print(f"\n  Total: 1 hero + {len(shots_to_test)} x 3 methods = {total} generations")
        sys.exit(0)

    # ── Step 1: Generate Wardrobe Hero Frame ──────────────────────────

    hero_info = {}

    if args.hero_url:
        print(f"\n  Using provided hero URL: {args.hero_url[:60]}...")
        hero_info = {
            "url": args.hero_url,
            "filename": None,
            "elapsed": 0,
            "seed": args.seed,
            "prompt": "(provided externally)",
        }
    else:
        print("\n  Generating wardrobe hero frame...", end="", flush=True)

        # Wardrobe description: the character's visual text when present,
        # otherwise the hair/makeup field.
        wardrobe_desc = char_visual or char_data.get("hair_makeup", "")

        hero_prompt = WARDROBE_HERO_PROMPT.format(
            trigger=trigger,
            visual=char_visual,
            wardrobe_desc=wardrobe_desc,
        )

        t0 = time.time()
        try:
            hero_result = generate_z_image_t2i(
                hero_prompt, z_loras, args.seed, PORTRAIT,
            )
            hero_elapsed = time.time() - t0
            hero_url = get_image_url(hero_result)

            if not hero_url:
                print(" no image returned")
                print(f"  Result: {json.dumps(hero_result, indent=2)[:500]}")
                sys.exit(1)

            hero_filename = f"wardrobe_hero_{char}.png"
            if not download_image(hero_url, output_dir / hero_filename):
                print(" download failed")
                sys.exit(1)

            print(f" OK ({hero_elapsed:.1f}s)")
            hero_info = {
                "url": hero_url,
                "filename": hero_filename,
                "elapsed": hero_elapsed,
                "seed": hero_result.get("seed", args.seed),
                "prompt": hero_prompt[:200],
            }
        except Exception as e:
            # sys.exit() raises SystemExit, which is not an Exception, so the
            # exits above pass through this handler untouched.
            print(f" ERROR: {e}")
            sys.exit(1)

    hero_cdn_url = hero_info["url"]

    # ── Step 2: Generate Each Shot x 3 Methods ──────────────────────

    all_results = []

    for i, shot in enumerate(shots_to_test):
        sid = shot["id"]
        name = shot.get("name", "")
        shot_type = shot.get("shot_type", "MS")
        approach = shot.get("generation_approach", "unknown")

        print(f"\n[{i+1}/{len(shots_to_test)}] S{sid:02d}: {name} [{shot_type}]")

        # Build prompts
        z_image_prompt = build_shot_prompt(engine, shot, "hero")
        flux_edit_prompt = build_flux2_edit_prompt(shot, char_visual, "hero")

        result_entry = {
            "shot_id": sid,
            "name": name,
            "shot_type": shot_type,
            "generation_approach": approach,
            "prompt_preview": z_image_prompt[:200],
            "seed": args.seed,
        }

        # ── Method A: Z-Image T2I + LoRA ──
        print("    A) Z-Image T2I + LoRA...", end="", flush=True)
        result_entry["z_image_t2i"] = _run_method(
            lambda: generate_z_image_t2i(z_image_prompt, z_loras, args.seed, PORTRAIT),
            f"S{sid:02d}_a_z_t2i.png", output_dir, args.seed,
        )

        # ── Method B: Flux 2 Edit + References ──
        print("    B) Flux 2 Edit + refs...", end="", flush=True)
        result_entry["flux2_edit"] = _run_method(
            lambda: generate_flux2_edit(
                flux_edit_prompt,
                image_urls=[hero_cdn_url],
                seed=args.seed,
                image_size=PORTRAIT,
            ),
            f"S{sid:02d}_b_flux_edit.png", output_dir, args.seed,
        )

        # ── Method C: Z-Image img2img + LoRA ──
        print(f"    C) Z-Image img2img + LoRA (str={args.img2img_strength})...", end="", flush=True)
        result_entry["z_image_img2img"] = _run_method(
            lambda: generate_z_image_img2img(
                z_image_prompt, z_loras, args.seed,
                hero_cdn_url, args.img2img_strength, PORTRAIT,
            ),
            f"S{sid:02d}_c_z_img2img.png", output_dir, args.seed,
        )

        all_results.append(result_entry)

    # ── Step 3: Build HTML Report ─────────────────────────────────────

    html_path = build_comparison_html(
        all_results, hero_info, output_dir, args.episode, char,
    )

    # Write results JSON
    json_path = output_dir / f"wardrobe_test_ep_{args.episode:03d}_{char}.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({
            "episode": args.episode,
            "character": char,
            "seed": args.seed,
            "hero_info": hero_info,
            "img2img_strength": args.img2img_strength,
            "generated_at": datetime.now().isoformat(),
            "results": all_results,
        }, f, indent=2)

    # Summary
    print(f"\n{rule}")
    print("  RESULTS")
    print(rule)
    for mk in ["z_image_t2i", "flux2_edit", "z_image_img2img"]:
        successes = sum(1 for r in all_results if r.get(mk, {}).get("filename"))
        errors = sum(1 for r in all_results if r.get(mk, {}).get("error"))
        # Average over successful runs only.
        times = [r.get(mk, {}).get("elapsed", 0) for r in all_results if r.get(mk, {}).get("filename")]
        avg = sum(times) / max(len(times), 1)
        print(f"  {METHODS[mk]['label']:30s}  {successes} OK  {errors} fail  avg {avg:.1f}s")
    print(f"\n  HTML: {html_path}")
    print(f"  JSON: {json_path}")
    print(rule)


if __name__ == "__main__":
    main()
