#!/usr/bin/env python3
"""
ab_test_comprehensive.py — Comprehensive visual quality test across 5 generation methods.

Tests Z-Image Turbo vs Base vs ControlNet vs Gemini wardrobe swap
using the same shots, seed, and LoRA for direct comparison.

Methods:
  A) Z-Image Turbo T2I + LoRA        — current pipeline baseline (8 steps, no CFG)
  B) Z-Image Base T2I + LoRA         — higher quality variant (28 steps, CFG 4.0)
  C) Z-Image Base T2I + LoRA + neg   — same as B with negative prompt for anti-deformity
  D) Z-Image ControlNet depth + LoRA — depth map from wardrobe hero as structural guide
  E) Gemini wardrobe swap            — generate with A, then swap wardrobe via Gemini edit

Usage:
    python3 ab_test_comprehensive.py leviathan/ --episode 1 --character jinx
    python3 ab_test_comprehensive.py leviathan/ --episode 1 --character jinx --shots 3,5,9
    python3 ab_test_comprehensive.py leviathan/ --episode 1 --character jinx --dry-run

Env vars:
    FAL_KEY — fal.ai API key (required)
"""

import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

# Add tools to path
_tools_dir = str(Path(__file__).resolve().parent)
if _tools_dir not in sys.path:
    sys.path.insert(0, _tools_dir)

# CP-3 Phase 8 (2026-04-26): post-stub-deletion bootstrap.
# PromptEngine now lives in recoil/lib/prompt_compiler.py (migrated Phase 5).
# Match the sys.path-relative pattern from the deleted Phase 5 stub.
_LIB_DIR = str(Path(__file__).resolve().parent.parent / "lib")
if _LIB_DIR not in sys.path:
    sys.path.insert(0, _LIB_DIR)

from train_lora import load_registry, get_inference_config
from prompt_compiler import PromptEngine


# ── Method Configurations ────────────────────────────────────────────────

# Registry of the five generation methods under comparison, keyed by a
# "<letter>_<slug>" identifier. The leading letter is what the report builder
# matches against the selected-methods list (it tests ``key[0]``).
# Each entry carries display metadata (label, short name, description, report
# colour) and a pricing figure -- presumably USD per megapixel
# (``cost_per_mp``) or per request (``cost_per_request``); confirm against the
# provider's price list.
METHODS = {
    "a_turbo_t2i": dict(
        label="Z-Image Turbo T2I + LoRA",
        short="Turbo T2I",
        description="Current pipeline. 8 steps, no CFG, no negative prompt.",
        color="#00e5ff",
        cost_per_mp=0.0085,
    ),
    "b_base_t2i": dict(
        label="Z-Image Base T2I + LoRA",
        short="Base T2I",
        description="Higher quality. 28 steps, CFG 4.0, no negative prompt.",
        color="#b388ff",
        cost_per_mp=0.012,
    ),
    "c_base_neg": dict(
        label="Z-Image Base + LoRA + Neg",
        short="Base + Neg",
        description="Same as Base but with anti-deformity negative prompt.",
        color="#ffab40",
        cost_per_mp=0.012,
    ),
    "d_controlnet": dict(
        label="ControlNet Depth + LoRA",
        short="CtrlNet Depth",
        description="Depth map from hero frame as structural guide. Turbo 8 steps.",
        color="#69f0ae",
        cost_per_mp=0.0065,
    ),
    "e_gemini_swap": dict(
        label="Gemini Wardrobe Swap",
        short="Gemini Swap",
        description="Generate with Turbo T2I, then swap wardrobe via Gemini 2.5 Flash.",
        color="#ff5252",
        cost_per_request=0.04,
    ),
}

# Method identifiers in declaration order (A through E).
METHOD_KEYS = list(METHODS)

# Portrait canvas requested for every generation, and its area in megapixels
# (768 x 1344 = 1.032 MP), which the per-megapixel cost figures are scaled by.
PORTRAIT = {"width": 768, "height": 1344}
PORTRAIT_MP = PORTRAIT["width"] * PORTRAIT["height"] / 1_000_000

# Negative prompt for method C: anatomy failures first, then quality and
# non-photographic styles to steer away from.
NEGATIVE_PROMPT = ", ".join((
    "deformed hands",
    "extra fingers",
    "mutated hands",
    "poorly drawn hands",
    "extra limbs",
    "fused fingers",
    "too many fingers",
    "long neck",
    "blurry",
    "low quality",
    "illustration",
    "cartoon",
    "painting",
    "drawing",
    "3d render",
    "anime",
    "cgi",
    "digital art",
    "smooth skin",
))

# Wardrobe hero frame prompt template. Placeholders filled via str.format:
# {trigger}, {visual}, {wardrobe_desc}.
WARDROBE_HERO_PROMPT = " ".join((
    "{trigger}, {visual}.",
    "Full body three-quarter view, standing upright, clean studio-quality "
    "lighting with dramatic side light, dark industrial background.",
    "Showing complete outfit clearly: {wardrobe_desc}.",
    "Photorealistic, ultra-detailed material textures, correct human anatomy, "
    "natural hands and fingers, sharp focus",
))


# ── Generation Functions ────────────────────────────────────────────────

def generate_turbo_t2i(prompt: str, loras: list, seed: int,
                       image_size: dict) -> dict:
    """Method A: submit a Z-Image Turbo text-to-image request with LoRA weights.

    The request is fixed at 8 inference steps, PNG output, and has the
    endpoint's safety checker switched off.

    Args:
        prompt: Text prompt sent to the model.
        loras: LoRA entries forwarded unchanged as the ``loras`` argument.
        seed: Seed forwarded to the endpoint.
        image_size: Size mapping forwarded as ``image_size``.

    Returns:
        Whatever ``fal_client.subscribe`` returns for the request.
    """
    # Imported lazily so the module can be loaded without the fal client.
    import fal_client

    endpoint = "fal-ai/z-image/turbo/lora"
    payload = dict(
        prompt=prompt,
        loras=loras,
        seed=seed,
        image_size=image_size,
        num_inference_steps=8,
        enable_safety_checker=False,
        output_format="png",
    )
    return fal_client.subscribe(endpoint, arguments=payload)


def generate_base_t2i(prompt: str, loras: list, seed: int,
                      image_size: dict,
                      negative_prompt: str = "",
                      guidance_scale: float = 4.0) -> dict:
    """Methods B/C: submit a Z-Image Base text-to-image request with LoRA weights.

    Uses 28 inference steps with the given guidance scale. A negative prompt
    is attached only when a non-empty one is supplied, which is the sole
    difference between method B (none) and method C.

    Args:
        prompt: Text prompt sent to the model.
        loras: LoRA entries forwarded unchanged as the ``loras`` argument.
        seed: Seed forwarded to the endpoint.
        image_size: Size mapping forwarded as ``image_size``.
        negative_prompt: Optional negative prompt; omitted from the request
            when empty.
        guidance_scale: CFG value forwarded as ``guidance_scale``.

    Returns:
        Whatever ``fal_client.subscribe`` returns for the request.
    """
    # Imported lazily so the module can be loaded without the fal client.
    import fal_client

    base_fields = dict(
        prompt=prompt,
        loras=loras,
        seed=seed,
        image_size=image_size,
        num_inference_steps=28,
        guidance_scale=guidance_scale,
        enable_safety_checker=False,
        output_format="png",
    )
    # Leave the key out entirely (rather than sending "") when there is no
    # negative prompt.
    optional_fields = {"negative_prompt": negative_prompt} if negative_prompt else {}
    payload = {**base_fields, **optional_fields}
    return fal_client.subscribe("fal-ai/z-image/base/lora", arguments=payload)


def generate_controlnet_depth(prompt: str, loras: list, seed: int,
                              image_size: dict,
                              control_image_url: str,
                              control_scale: float = 0.5) -> dict:
    """Method D: submit a Z-Image Turbo ControlNet request (depth) with LoRA weights.

    The control image is sent with ``preprocess="depth"``, and the control
    window is fixed at ``control_start=0.0`` to ``control_end=0.8``. The
    request uses 8 inference steps and PNG output.

    Args:
        prompt: Text prompt sent to the model.
        loras: LoRA entries forwarded unchanged as the ``loras`` argument.
        seed: Seed forwarded to the endpoint.
        image_size: Size mapping forwarded as ``image_size``.
        control_image_url: URL of the control image, sent as ``image_url``.
        control_scale: Value forwarded as ``control_scale``.

    Returns:
        Whatever ``fal_client.subscribe`` returns for the request.
    """
    # Imported lazily so the module can be loaded without the fal client.
    import fal_client

    endpoint = "fal-ai/z-image/turbo/controlnet/lora"
    payload = dict(
        prompt=prompt,
        loras=loras,
        seed=seed,
        image_url=control_image_url,
        preprocess="depth",
        control_scale=control_scale,
        control_start=0.0,
        control_end=0.8,
        image_size=image_size,
        num_inference_steps=8,
        enable_safety_checker=False,
        output_format="png",
    )
    return fal_client.subscribe(endpoint, arguments=payload)


def generate_gemini_swap(source_image_url: str, hero_image_url: str,
                         character_name: str) -> dict:
    """Method E: submit a Gemini 2.5 Flash Image edit request for a wardrobe swap.

    Sends two images -- the source shot first, the wardrobe reference second --
    with a fixed instruction to change only the clothing.

    Args:
        source_image_url: URL of the image whose wardrobe should be replaced.
        hero_image_url: URL of the reference image showing the target outfit.
        character_name: Accepted for interface compatibility; not used in the
            request.

    Returns:
        Whatever ``fal_client.subscribe`` returns for the request.
    """
    # Imported lazily so the module can be loaded without the fal client.
    import fal_client

    instruction = (
        "Change this character's clothing and gear to exactly match the outfit "
        "shown in the second reference image. Preserve the character's face, "
        "expression, pose, body position, background environment, and lighting "
        "exactly as they are. Only change the wardrobe and accessories to match "
        "the reference outfit. The result should look like the same photograph "
        "but with different clothes."
    )
    payload = dict(
        prompt=instruction,
        # Order matters: the prompt refers to the outfit in the *second* image.
        image_urls=[source_image_url, hero_image_url],
        num_images=1,
        output_format="png",
    )
    return fal_client.subscribe("fal-ai/gemini-25-flash-image/edit", arguments=payload)


def download_image(url: str, path: Path) -> bool:
    """Fetch ``url`` and write the response body to ``path``.

    Best-effort: any failure (request error, non-2xx status, write error) is
    printed and reported through the return value instead of being raised.

    Returns:
        True if the file was written, False otherwise.
    """
    # Imported lazily; an ImportError here is intentionally not swallowed.
    import requests

    try:
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        path.write_bytes(response.content)
    except Exception as exc:
        print(f"    Download failed: {exc}")
        return False
    return True


def get_image_url(result: dict) -> Optional[str]:
    """Return the URL of the first image in a fal.ai result.

    Returns None when the result has no ``images`` entries, or when the first
    entry has no ``url`` key.
    """
    images = result.get("images", [])
    if not images:
        return None
    first_image = images[0]
    return first_image.get("url")


# ── HTML Report ─────────────────────────────────────────────────────────

def build_comparison_html(results: list, hero_info: dict,
                          output_dir: Path, episode: int,
                          character: str,
                          active_methods: Optional[list] = None) -> Path:
    """Write a self-contained HTML comparison report and return its path.

    The report has a header, the wardrobe hero frame, a method legend, one row
    per shot with one column per displayed method, and an aggregate
    summary/cost section. Columns adapt to the active methods.

    Args:
        results: One dict per shot. Reads ``shot_id`` (int) and ``name``
            (required), ``shot_type``, ``generation_approach`` and
            ``prompt_preview`` (optional), plus one optional sub-dict per
            method key holding ``filename``/``elapsed``/``cost`` on success
            or ``error`` on failure.
        hero_info: Hero frame metadata. Reads ``filename``, ``url``,
            ``elapsed``, ``seed``, ``prompt`` and ``cost``; all are optional.
        output_dir: Existing directory the HTML file is written into. Image
            filenames in ``results`` are emitted as-is, so they resolve
            relative to this directory.
        episode: Episode number, zero-padded to three digits in the title and
            in the output filename.
        character: Character name; title-cased for display and used verbatim
            in the output filename.
        active_methods: Method letters (e.g. ``["a", "d"]``) to display. A
            method key is shown when its first character is in this list;
            ``None`` or an empty list shows every method.

    Returns:
        Path of the written ``comprehensive_test_ep_XXX_<character>.html``.
    """
    # Function-scope import: only the report builder needs HTML escaping.
    from html import escape

    show_methods = [mk for mk in METHOD_KEYS
                    if not active_methods or mk[0] in active_methods]
    n_cols = len(show_methods) or 5

    # Everything that does not come from the METHODS constants is escaped
    # before interpolation: API error messages and prompts routinely contain
    # '<', '&' or quotes, which would otherwise break the markup.
    character_display = escape(character.title())
    seed_display = escape(str(hero_info.get('seed', 42)))
    generated_at = datetime.now().strftime('%Y-%m-%d %H:%M')

    # Collect fragments and join once at the end instead of repeated "+=".
    parts = [f"""<!DOCTYPE html>
<html><head><meta charset="UTF-8">
<title>Comprehensive Quality Test — EP{episode:03d} — {character_display}</title>
<style>
  * {{ box-sizing: border-box; margin: 0; padding: 0; }}
  body {{ background: #0a0a0f; color: #c8c8d4; font-family: 'SF Mono', 'Fira Code', 'Cascadia Code', monospace; }}

  /* Header */
  .header {{
    padding: 24px 32px 20px;
    border-bottom: 1px solid #1a1a28;
    background: linear-gradient(180deg, #0e0e16 0%, #0a0a0f 100%);
  }}
  .header h1 {{
    font-size: 13px;
    letter-spacing: 4px;
    text-transform: uppercase;
    color: #00e5ff;
    margin-bottom: 8px;
  }}
  .header .subtitle {{
    font-size: 11px;
    color: #505068;
    line-height: 1.6;
  }}
  .header .subtitle span {{ color: #707088; }}

  /* Hero frame section */
  .hero-section {{
    padding: 24px 32px;
    border-bottom: 1px solid #1a1a28;
    background: #0d0d14;
  }}
  .hero-section h2 {{
    font-size: 11px;
    letter-spacing: 3px;
    text-transform: uppercase;
    color: #b388ff;
    margin-bottom: 12px;
  }}
  .hero-content {{
    display: flex;
    gap: 20px;
    align-items: flex-start;
  }}
  .hero-content img {{
    height: 280px;
    width: auto;
    border-radius: 6px;
    border: 1px solid #2a2a3e;
    cursor: pointer;
  }}
  .hero-content img:hover {{ border-color: #b388ff; }}
  .hero-meta {{
    font-size: 10px;
    color: #505068;
    line-height: 1.8;
  }}
  .hero-meta .label {{ color: #707088; }}

  /* Method legend */
  .legend {{
    display: flex;
    gap: 16px;
    padding: 14px 32px;
    border-bottom: 1px solid #1a1a28;
    background: #0c0c12;
    flex-wrap: wrap;
  }}
  .legend-item {{
    display: flex;
    align-items: center;
    gap: 6px;
  }}
  .legend-dot {{
    width: 8px;
    height: 8px;
    border-radius: 50%;
    flex-shrink: 0;
  }}
  .legend-label {{ font-size: 10px; color: #707088; white-space: nowrap; }}
  .legend-desc {{ font-size: 9px; color: #404058; }}

  /* Shot rows */
  .shots {{ padding: 16px 32px; }}

  .shot {{
    background: #0e0e16;
    border: 1px solid #1a1a28;
    border-radius: 8px;
    margin-bottom: 20px;
    overflow: hidden;
  }}
  .shot-header {{
    display: flex;
    align-items: center;
    gap: 12px;
    padding: 12px 16px;
    border-bottom: 1px solid #1a1a28;
    background: #101018;
  }}
  .shot-header h3 {{
    font-size: 12px;
    color: #d0d0d8;
  }}
  .shot-header .tag {{
    font-size: 9px;
    padding: 2px 8px;
    border-radius: 3px;
    text-transform: uppercase;
    letter-spacing: 1px;
  }}
  .tag-shot-type {{ background: rgba(0,229,255,0.08); color: #00e5ff; }}
  .tag-approach {{ background: rgba(179,136,255,0.08); color: #b388ff; }}

  .shot-prompt {{
    padding: 8px 16px;
    font-size: 9px;
    color: #404058;
    border-bottom: 1px solid #14141e;
    line-height: 1.5;
    max-height: 36px;
    overflow: hidden;
    cursor: pointer;
  }}
  .shot-prompt:hover {{
    max-height: none;
    color: #606078;
  }}

  .compare {{
    display: grid;
    grid-template-columns: repeat({n_cols}, 1fr);
    gap: 1px;
    background: #1a1a28;
  }}

  .method-col {{
    background: #0e0e16;
    padding: 10px;
  }}
  .method-col h4 {{
    font-size: 8px;
    letter-spacing: 1.5px;
    text-transform: uppercase;
    margin-bottom: 6px;
    white-space: nowrap;
  }}
  .method-col img {{
    width: 100%;
    border-radius: 4px;
    border: 1px solid #1a1a28;
    cursor: pointer;
    transition: border-color 0.15s;
  }}
  .method-col img:hover {{ border-color: #00e5ff; }}
  .method-meta {{
    display: flex;
    justify-content: space-between;
    margin-top: 4px;
    font-size: 8px;
    color: #404058;
  }}
  .method-meta .cost {{ color: #69f0ae; }}
  .error {{ color: #ef5350; font-size: 9px; padding: 16px; text-align: center; }}

  /* Summary */
  .summary {{
    padding: 24px 32px;
    border-top: 1px solid #1a1a28;
    background: #0c0c12;
  }}
  .summary h2 {{
    font-size: 11px;
    letter-spacing: 3px;
    text-transform: uppercase;
    color: #69f0ae;
    margin-bottom: 12px;
  }}
  .summary-grid {{
    display: grid;
    grid-template-columns: repeat({n_cols}, 1fr);
    gap: 12px;
  }}
  .summary-card {{
    background: #0e0e16;
    border: 1px solid #1a1a28;
    border-radius: 6px;
    padding: 14px;
  }}
  .summary-card h3 {{
    font-size: 9px;
    letter-spacing: 1.5px;
    text-transform: uppercase;
    margin-bottom: 8px;
  }}
  .summary-card .stat {{
    font-size: 18px;
    font-weight: 600;
    margin-bottom: 4px;
  }}
  .summary-card .stat-label {{
    font-size: 9px;
    color: #505068;
  }}

  /* Cost table */
  .cost-table {{
    margin-top: 20px;
    width: 100%;
    border-collapse: collapse;
    font-size: 10px;
  }}
  .cost-table th {{
    text-align: left;
    padding: 8px 12px;
    color: #707088;
    border-bottom: 1px solid #1a1a28;
    font-weight: normal;
    letter-spacing: 1px;
    text-transform: uppercase;
    font-size: 9px;
  }}
  .cost-table td {{
    padding: 6px 12px;
    border-bottom: 1px solid #14141e;
  }}
  .cost-table .total {{ color: #69f0ae; font-weight: 600; }}

  /* Lightbox */
  .lightbox {{
    display: none;
    position: fixed;
    inset: 0;
    background: rgba(0,0,0,0.95);
    z-index: 100;
    align-items: center;
    justify-content: center;
  }}
  .lightbox.open {{ display: flex; }}
  .lightbox img {{
    max-width: 95vw;
    max-height: 95vh;
    object-fit: contain;
    border-radius: 4px;
  }}
  .lightbox .close {{
    position: absolute;
    top: 16px;
    right: 20px;
    font-size: 28px;
    color: #505068;
    cursor: pointer;
    background: none;
    border: none;
    font-family: monospace;
  }}
  .lightbox .close:hover {{ color: #c8c8d4; }}
  .lightbox .nav {{
    position: absolute;
    font-size: 11px;
    color: #505068;
    bottom: 20px;
    left: 50%;
    transform: translateX(-50%);
  }}

  /* Responsive: scroll on narrow screens */
  @media (max-width: 1400px) {{
    .compare {{ overflow-x: auto; grid-template-columns: repeat({n_cols}, minmax(200px, 1fr)); }}
    .summary-grid {{ overflow-x: auto; grid-template-columns: repeat({n_cols}, minmax(160px, 1fr)); }}
  }}
</style>
</head>
<body>

<div class="header">
  <h1>Comprehensive Quality Test — EP{episode:03d}</h1>
  <div class="subtitle">
    <span>Character:</span> {character_display} &nbsp;|&nbsp;
    <span>Generated:</span> {generated_at} &nbsp;|&nbsp;
    <span>Shots:</span> {len(results)} &nbsp;|&nbsp;
    <span>Seed:</span> {seed_display}
  </div>
  <div class="subtitle" style="margin-top:6px; color:#404058;">
    Testing Z-Image Turbo vs Base (with/without negative prompt) vs ControlNet depth
    vs Gemini wardrobe swap for visual quality and wardrobe consistency.
  </div>
</div>

<div class="hero-section">
  <h2>Wardrobe Hero Frame (Reference Anchor)</h2>
  <div class="hero-content">
"""]

    # Hero frame image. Prefer the locally downloaded file; fall back to the
    # remote URL so an externally supplied hero (filename is None) is still
    # shown instead of being reported as a failed generation.
    hero_src = hero_info.get("filename") or hero_info.get("url")
    hero_elapsed = hero_info.get("elapsed", 0)
    hero_cost = hero_info.get('cost', 0.009)
    hero_prompt_display = escape(str(hero_info.get("prompt") or "")[:200])

    if hero_src:
        parts.append(f'    <img src="{escape(str(hero_src))}" onclick="openLB(this.src)">\n')
    else:
        parts.append('    <div class="error">Hero frame generation failed</div>\n')

    parts.append(f"""    <div class="hero-meta">
      <div><span class="label">Method:</span> Z-Image Turbo T2I + LoRA (clean reference shot)</div>
      <div><span class="label">Time:</span> {hero_elapsed:.1f}s</div>
      <div><span class="label">Purpose:</span> Used as depth source for ControlNet (D), wardrobe reference for Gemini swap (E)</div>
      <div style="margin-top:8px;"><span class="label">Prompt:</span> {hero_prompt_display}...</div>
    </div>
  </div>
</div>

<div class="legend">
""")

    # Legend: METHODS values are trusted module constants, so they are
    # interpolated without escaping.
    for mk in show_methods:
        m = METHODS[mk]
        parts.append(f"""  <div class="legend-item">
    <div class="legend-dot" style="background:{m['color']};"></div>
    <span class="legend-label">{m['short']}</span>
    <span class="legend-desc">{m['description'][:60]}</span>
  </div>
""")

    parts.append("</div>\n\n<div class=\"shots\">\n")

    # Shot rows: one card per shot, one column per displayed method.
    for r in results:
        shot_id = r['shot_id']
        shot_name = escape(str(r['name']))
        shot_type = escape(str(r.get("shot_type", "MS")))
        approach = escape(str(r.get("generation_approach", "unknown")).replace('_', ' '))
        # Truncate first, then escape, so an entity is never cut in half.
        prompt_preview = escape(str(r.get('prompt_preview', ''))[:200])

        parts.append(f"""<div class="shot">
  <div class="shot-header">
    <h3>S{shot_id:02d}: {shot_name}</h3>
    <span class="tag tag-shot-type">{shot_type}</span>
    <span class="tag tag-approach">{approach}</span>
  </div>
  <div class="shot-prompt" title="Click to expand">{prompt_preview}</div>
  <div class="compare">
""")

        for mk in show_methods:
            m = METHODS[mk]
            mr = r.get(mk, {})

            parts.append('    <div class="method-col">\n')
            parts.append(f'      <h4 style="color:{m["color"]}">{m["short"]}</h4>\n')

            if mr.get("error"):
                err_short = escape(str(mr["error"])[:100])
                parts.append(f'      <div class="error">{err_short}</div>\n')
            elif mr.get("filename"):
                elapsed = mr.get("elapsed", 0)
                cost = mr.get("cost", 0)
                img_src = escape(str(mr["filename"]))
                parts.append(f'      <img src="{img_src}" onclick="openLB(this.src)" loading="lazy">\n')
                parts.append(f'      <div class="method-meta"><span>{elapsed:.1f}s</span>')
                if cost > 0:
                    parts.append(f'<span class="cost">${cost:.4f}</span>')
                parts.append('</div>\n')
            else:
                parts.append('      <div class="error">No result</div>\n')

            parts.append('    </div>\n')

        parts.append("  </div>\n</div>\n\n")

    parts.append("</div>\n")

    # Summary section
    parts.append('<div class="summary">\n  <h2>Aggregate Results</h2>\n  <div class="summary-grid">\n')

    total_cost = 0
    method_costs = {}
    for mk in show_methods:
        m = METHODS[mk]
        successes = sum(1 for r in results if r.get(mk, {}).get("filename"))
        # Only successful generations count towards the average time.
        total_time = sum(r.get(mk, {}).get("elapsed", 0) for r in results
                         if r.get(mk, {}).get("filename"))
        avg_time = total_time / max(successes, 1)
        method_cost = sum(r.get(mk, {}).get("cost", 0) for r in results)
        method_costs[mk] = method_cost
        total_cost += method_cost

        parts.append(f"""    <div class="summary-card">
      <h3 style="color:{m['color']}">{m['short']}</h3>
      <div class="stat" style="color:{m['color']}">{successes}/{len(results)}</div>
      <div class="stat-label">shots generated</div>
      <div class="stat" style="color:{m['color']}; font-size:14px; margin-top:8px;">{avg_time:.1f}s</div>
      <div class="stat-label">avg generation time</div>
      <div class="stat" style="color:#69f0ae; font-size:12px; margin-top:6px;">${method_cost:.3f}</div>
      <div class="stat-label">total cost</div>
    </div>
""")

    parts.append(f"""  </div>

  <table class="cost-table" style="margin-top:20px; max-width:500px;">
    <tr><th>Item</th><th>Cost</th></tr>
    <tr><td>Hero frame</td><td>${hero_cost:.4f}</td></tr>
""")

    for mk in show_methods:
        m = METHODS[mk]
        mc = method_costs[mk]
        parts.append(f'    <tr><td>{m["short"]} ({len(results)} shots)</td><td>${mc:.4f}</td></tr>\n')

    parts.append(f"""    <tr><td class="total">TOTAL</td><td class="total">${total_cost + hero_cost:.4f}</td></tr>
  </table>
</div>

<div class="lightbox" id="lb" onclick="if(event.target===this)closeLB()">
  <button class="close" onclick="closeLB()">&times;</button>
  <img id="lb-img">
  <div class="nav">Click outside or press Escape to close</div>
</div>

<script>
function openLB(src) {{
  document.getElementById('lb-img').src = src;
  document.getElementById('lb').classList.add('open');
}}
function closeLB() {{
  document.getElementById('lb').classList.remove('open');
}}
document.addEventListener('keydown', e => {{
  if (e.key === 'Escape') closeLB();
}});
</script>
</body></html>""")

    html_path = output_dir / f"comprehensive_test_ep_{episode:03d}_{character}.html"
    # The page declares charset UTF-8 and contains non-ASCII characters, so
    # write it as UTF-8 explicitly rather than with the platform default.
    html_path.write_text("".join(parts), encoding="utf-8")
    return html_path


# ── Main ────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(
        description="Comprehensive visual quality test: 5 generation methods compared"
    )
    parser.add_argument("project_dir", help="Project directory (e.g. leviathan/)")
    parser.add_argument("-e", "--episode", type=int, required=True)
    parser.add_argument("-c", "--character", required=True,
                        help="Character to test (e.g. jinx)")
    parser.add_argument("--shots", help="Comma-separated shot IDs (default: auto-select 5)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--control-scale", type=float, default=0.5,
                        help="ControlNet depth scale for method D (default: 0.5)")
    parser.add_argument("--hero-url", help="Skip hero generation, use this CDN URL instead")
    parser.add_argument("--cfg", type=float, default=4.0,
                        help="Guidance scale for Base methods B/C (default: 4.0)")
    parser.add_argument("--methods", default="a,b,c,d,e",
                        help="Comma-separated methods to run (default: a,b,c,d,e)")
    parser.add_argument("--output-name",
                        help="Output subdirectory name (default: comprehensive_test)")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    active_methods = [m.strip().lower() for m in args.methods.split(",")]

    # Resolve paths
    project_dir = Path(args.project_dir).resolve()
    if not project_dir.is_dir():
        project_dir = Path.cwd() / args.project_dir
    if not project_dir.is_dir():
        print(f"ERROR: Project directory not found: {args.project_dir}")
        sys.exit(1)

    # Load storyboard
    ep_str = str(args.episode).zfill(3)
    sb_path = project_dir / "storyboards" / f"storyboard_ep_{ep_str}.json"
    if not sb_path.is_file():
        print(f"ERROR: Storyboard not found: {sb_path}")
        sys.exit(1)

    try:
        with open(sb_path) as f:
            storyboard = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in {sb_path}: {e}")
        sys.exit(1)

    # Load LoRA registry
    raw_registry = load_registry(project_dir)
    lora_registry = {}
    for char_name in raw_registry:
        lora_registry[char_name] = get_inference_config(raw_registry, char_name)

    char = args.character.lower()
    if char not in lora_registry:
        print(f"ERROR: Character '{char}' not found in LoRA registry")
        sys.exit(1)

    char_config = lora_registry[char]
    char_data = storyboard.get("characters", {}).get(char, {})
    char_visual = char_data.get("visual", "")

    # Prompt engine for structured prompts
    engine = PromptEngine(project_dir, storyboard, lora_registry, model_key="z_image")

    # LoRA paths — Turbo LoRA for method A, Base LoRA for methods B/C
    turbo_lora_path = char_config.get("z_image_t2i_path") or char_config.get("t2i_path")
    base_lora_path = char_config.get("z_image_base_t2i_path") or turbo_lora_path
    solo_scale = char_config.get("scale_solo", 0.9)

    turbo_loras = [{"path": turbo_lora_path, "scale": solo_scale}] if turbo_lora_path else []
    base_loras = [{"path": base_lora_path, "scale": solo_scale}] if base_lora_path else []

    # Fallback: use turbo for everything if no base LoRA
    z_lora_path = turbo_lora_path  # for ControlNet (method D)
    z_loras = turbo_loras  # legacy reference

    trigger = char_config.get("trigger", "")

    # Filter shots — solo shots for this character
    shot_ids = None
    if args.shots:
        shot_ids = [int(s.strip()) for s in args.shots.split(",")]

    shots_to_test = []
    for shot in storyboard.get("shots", []):
        sid = shot.get("id")
        if shot_ids and sid not in shot_ids:
            continue

        chars = [c.lower() for c in shot.get("characters_in_shot", [])]
        if char not in chars:
            continue
        if len(chars) > 1:
            continue

        shots_to_test.append(shot)

    if not shot_ids:
        # Auto-select: pick diverse shot types, max 5
        type_seen = {}
        filtered = []
        for shot in shots_to_test:
            st = shot.get("shot_type", "MS")
            if type_seen.get(st, 0) < 2:
                filtered.append(shot)
                type_seen[st] = type_seen.get(st, 0) + 1
            if len(filtered) >= 5:
                break
        shots_to_test = filtered

    if not shots_to_test:
        print("No matching shots found.")
        sys.exit(0)

    # Output directory
    dir_name = args.output_name or "comprehensive_test"
    output_dir = project_dir / "storyboards" / "assets" / f"ep_{ep_str}" / dir_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Cost estimate (filtered by active methods)
    n_shots = len(shots_to_test)
    need_hero_est = "d" in active_methods or "e" in active_methods
    hero_cost = PORTRAIT_MP * 0.0085 if need_hero_est else 0
    cost_per_method = {
        "a": PORTRAIT_MP * 0.0085,
        "b": PORTRAIT_MP * 0.012,
        "c": PORTRAIT_MP * 0.012,
        "d": PORTRAIT_MP * 0.0065,
        "e": 0.04,
    }
    total_est = hero_cost + sum(
        n_shots * cost_per_method.get(m, 0) for m in active_methods
    )
    n_active = len(active_methods)
    active_labels = [METHODS.get(f"{m}_{'turbo_t2i' if m=='a' else 'base_t2i' if m=='b' else 'base_neg' if m=='c' else 'controlnet' if m=='d' else 'gemini_swap'}", {}).get("short", m.upper()) for m in active_methods]
    method_label = ", ".join(METHODS[mk]["short"] for mk in METHOD_KEYS if mk[0] in active_methods)

    # Header
    print(f"{'=' * 70}")
    print(f"  COMPREHENSIVE QUALITY TEST — EP{args.episode:03d} — {char.upper()}")
    print(f"{'=' * 70}")
    print(f"  Shots: {n_shots}")
    print(f"  Methods: {n_active} ({method_label})")
    print(f"  Seed: {args.seed}")
    print(f"  Turbo LoRA: {turbo_lora_path or 'NONE'}")
    print(f"  Base LoRA:  {base_lora_path or 'NONE'}")
    print(f"  Base CFG: {args.cfg}")
    print(f"  ControlNet scale: {args.control_scale}")
    print(f"  Output: {output_dir}")
    print(f"  Estimated cost: ${total_est:.2f}")
    print(f"{'=' * 70}")

    if args.dry_run:
        print(f"\n  WARDROBE HERO FRAME: Would generate clean reference shot")
        print(f"  Trigger: {trigger}")
        print()
        for shot in shots_to_test:
            st = shot.get("shot_type", "MS")
            approach = shot.get("generation_approach", "?")
            print(f"  S{shot['id']:02d}: {shot['name']:<30s} [{st}] [{approach}]")

        hero_count = 1 if need_hero_est else 0
        total_gens = n_shots * n_active + hero_count
        hero_label = "1 hero + " if need_hero_est else ""
        print(f"\n  Total: {hero_label}{n_shots} x {n_active} methods = {total_gens} generations")
        print(f"\n  Cost breakdown:")
        print(f"    Hero:           ${hero_cost:.4f}")
        method_names = {"a": "Turbo T2I", "b": "Base T2I", "c": "Base + Neg",
                        "d": "ControlNet", "e": "Gemini swap"}
        for m in active_methods:
            mc = n_shots * cost_per_method[m]
            print(f"    {m.upper()}) {method_names.get(m, m):<14s} ${mc:.4f}  ({n_shots} x ${cost_per_method[m]:.4f})")
        print(f"    ────────────────────────")
        print(f"    TOTAL:          ${total_est:.4f}")
        sys.exit(0)

    # ── Step 1: Generate Wardrobe Hero Frame ──────────────────────────

    hero_info = {}
    need_hero = "d" in active_methods or "e" in active_methods

    if args.hero_url:
        print(f"\n  Using provided hero URL: {args.hero_url[:60]}...")
        hero_info = {
            "url": args.hero_url,
            "filename": None,
            "elapsed": 0,
            "seed": args.seed,
            "prompt": "(provided externally)",
            "cost": 0,
        }
    elif not need_hero:
        print(f"\n  Skipping hero frame (methods D/E not active)")
    else:
        print(f"\n  Generating wardrobe hero frame...", end="", flush=True)

        wardrobe_desc = char_data.get("hair_makeup", "")
        if char_visual:
            wardrobe_desc = char_visual

        hero_prompt = WARDROBE_HERO_PROMPT.format(
            trigger=trigger,
            visual=char_visual,
            wardrobe_desc=wardrobe_desc,
        )

        t0 = time.time()
        try:
            hero_result = generate_turbo_t2i(hero_prompt, z_loras, args.seed, PORTRAIT)
            hero_elapsed = time.time() - t0
            hero_url = get_image_url(hero_result)

            if hero_url:
                hero_filename = f"wardrobe_hero_{char}.png"
                hero_path = output_dir / hero_filename
                if download_image(hero_url, hero_path):
                    print(f" OK ({hero_elapsed:.1f}s)")
                    hero_info = {
                        "url": hero_url,
                        "filename": hero_filename,
                        "elapsed": hero_elapsed,
                        "seed": hero_result.get("seed", args.seed),
                        "prompt": hero_prompt[:200],
                        "cost": PORTRAIT_MP * 0.0085,
                    }
                else:
                    print(f" download failed")
                    sys.exit(1)
            else:
                print(f" no image returned")
                sys.exit(1)
        except Exception as e:
            print(f" ERROR: {e}")
            sys.exit(1)

    hero_cdn_url = hero_info["url"]

    # ── Step 2: Generate Each Shot x 5 Methods ──────────────────────

    all_results = []

    for i, shot in enumerate(shots_to_test):
        sid = shot["id"]
        name = shot.get("name", "")
        shot_type = shot.get("shot_type", "MS")
        approach = shot.get("generation_approach", "unknown")

        print(f"\n[{i+1}/{n_shots}] S{sid:02d}: {name} [{shot_type}]")

        # Build prompt using the prompt engine
        z_prompt = engine.compile(shot, "hero")

        result_entry = {
            "shot_id": sid,
            "name": name,
            "shot_type": shot_type,
            "generation_approach": approach,
            "prompt_preview": z_prompt[:200],
            "seed": args.seed,
        }

        # ── Method A: Z-Image Turbo T2I + LoRA ──

        print(f"    A) Turbo T2I + LoRA...", end="", flush=True)
        t0 = time.time()
        try:
            res = generate_turbo_t2i(z_prompt, turbo_loras, args.seed, PORTRAIT)
            elapsed = time.time() - t0
            url = get_image_url(res)
            cost = PORTRAIT_MP * 0.0085
            if url:
                fn = f"S{sid:02d}_a_turbo_t2i.png"
                if download_image(url, output_dir / fn):
                    print(f" OK ({elapsed:.1f}s)")
                    result_entry["a_turbo_t2i"] = {
                        "filename": fn, "elapsed": elapsed,
                        "seed": res.get("seed", args.seed),
                        "cost": cost, "cdn_url": url,
                    }
                else:
                    result_entry["a_turbo_t2i"] = {"error": "download failed", "elapsed": elapsed}
            else:
                print(f" no image")
                result_entry["a_turbo_t2i"] = {"error": "no image returned", "elapsed": elapsed}
        except Exception as e:
            elapsed = time.time() - t0
            print(f" ERROR: {e}")
            result_entry["a_turbo_t2i"] = {"error": str(e), "elapsed": elapsed}

        # ── Method B: Z-Image Base T2I + LoRA ──

        print(f"    B) Base T2I + LoRA...", end="", flush=True)
        t0 = time.time()
        try:
            res = generate_base_t2i(z_prompt, base_loras, args.seed, PORTRAIT,
                                    guidance_scale=args.cfg)
            elapsed = time.time() - t0
            url = get_image_url(res)
            cost = PORTRAIT_MP * 0.012
            if url:
                fn = f"S{sid:02d}_b_base_t2i.png"
                if download_image(url, output_dir / fn):
                    print(f" OK ({elapsed:.1f}s)")
                    result_entry["b_base_t2i"] = {
                        "filename": fn, "elapsed": elapsed,
                        "seed": res.get("seed", args.seed),
                        "cost": cost,
                    }
                else:
                    result_entry["b_base_t2i"] = {"error": "download failed", "elapsed": elapsed}
            else:
                print(f" no image")
                result_entry["b_base_t2i"] = {"error": "no image returned", "elapsed": elapsed}
        except Exception as e:
            elapsed = time.time() - t0
            print(f" ERROR: {e}")
            result_entry["b_base_t2i"] = {"error": str(e), "elapsed": elapsed}

        # ── Method C: Z-Image Base T2I + LoRA + Negative Prompt ──

        print(f"    C) Base + Neg prompt...", end="", flush=True)
        t0 = time.time()
        try:
            res = generate_base_t2i(z_prompt, base_loras, args.seed, PORTRAIT,
                                    negative_prompt=NEGATIVE_PROMPT,
                                    guidance_scale=args.cfg)
            elapsed = time.time() - t0
            url = get_image_url(res)
            cost = PORTRAIT_MP * 0.012
            if url:
                fn = f"S{sid:02d}_c_base_neg.png"
                if download_image(url, output_dir / fn):
                    print(f" OK ({elapsed:.1f}s)")
                    result_entry["c_base_neg"] = {
                        "filename": fn, "elapsed": elapsed,
                        "seed": res.get("seed", args.seed),
                        "cost": cost,
                    }
                else:
                    result_entry["c_base_neg"] = {"error": "download failed", "elapsed": elapsed}
            else:
                print(f" no image")
                result_entry["c_base_neg"] = {"error": "no image returned", "elapsed": elapsed}
        except Exception as e:
            elapsed = time.time() - t0
            print(f" ERROR: {e}")
            result_entry["c_base_neg"] = {"error": str(e), "elapsed": elapsed}

        # ── Method D: Z-Image ControlNet (depth) + LoRA ──

        if "d" in active_methods:
            print(f"    D) ControlNet depth (scale={args.control_scale})...", end="", flush=True)
            t0 = time.time()
            try:
                res = generate_controlnet_depth(
                    z_prompt, turbo_loras, args.seed, PORTRAIT,
                    hero_cdn_url, args.control_scale,
                )
                elapsed = time.time() - t0
                url = get_image_url(res)
                cost = PORTRAIT_MP * 0.0065
                if url:
                    fn = f"S{sid:02d}_d_controlnet.png"
                    if download_image(url, output_dir / fn):
                        print(f" OK ({elapsed:.1f}s)")
                        result_entry["d_controlnet"] = {
                            "filename": fn, "elapsed": elapsed,
                            "seed": res.get("seed", args.seed),
                            "cost": cost,
                        }
                    else:
                        result_entry["d_controlnet"] = {"error": "download failed", "elapsed": elapsed}
                else:
                    print(f" no image")
                    result_entry["d_controlnet"] = {"error": "no image returned", "elapsed": elapsed}
            except Exception as e:
                elapsed = time.time() - t0
                print(f" ERROR: {e}")
                result_entry["d_controlnet"] = {"error": str(e), "elapsed": elapsed}

        # ── Method E: Gemini Wardrobe Swap ──

        if "e" in active_methods:
            # Method E uses Method A's output as source, then swaps wardrobe via Gemini
            source_url = result_entry.get("a_turbo_t2i", {}).get("cdn_url")
            if not source_url:
                print(f"    E) Gemini swap... SKIP (no method A output)")
                result_entry["e_gemini_swap"] = {"error": "no source image (method A failed)"}
            else:
                print(f"    E) Gemini swap...", end="", flush=True)
                t0 = time.time()
                try:
                    res = generate_gemini_swap(source_url, hero_cdn_url, char)
                    elapsed = time.time() - t0
                    url = get_image_url(res)
                    cost = 0.04
                    if url:
                        fn = f"S{sid:02d}_e_gemini.png"
                        if download_image(url, output_dir / fn):
                            print(f" OK ({elapsed:.1f}s)")
                            result_entry["e_gemini_swap"] = {
                                "filename": fn, "elapsed": elapsed,
                                "cost": cost,
                            }
                        else:
                            result_entry["e_gemini_swap"] = {"error": "download failed", "elapsed": elapsed}
                    else:
                        print(f" no image")
                        result_entry["e_gemini_swap"] = {"error": "no image returned", "elapsed": elapsed}
                except Exception as e:
                    elapsed = time.time() - t0
                    print(f" ERROR: {e}")
                    result_entry["e_gemini_swap"] = {"error": str(e), "elapsed": elapsed}

        all_results.append(result_entry)

    # ── Step 3: Build HTML Report ─────────────────────────────────────

    html_path = build_comparison_html(
        all_results, hero_info, output_dir, args.episode, char,
        active_methods=active_methods,
    )

    # Write results JSON
    json_path = output_dir / f"comprehensive_test_ep_{args.episode:03d}_{char}.json"
    with open(json_path, "w") as f:
        json.dump({
            "episode": args.episode,
            "character": char,
            "seed": args.seed,
            "control_scale": args.control_scale,
            "hero_info": hero_info,
            "generated_at": datetime.now().isoformat(),
            "methods": {k: v["label"] for k, v in METHODS.items()},
            "results": all_results,
        }, f, indent=2)

    # Summary
    print(f"\n{'=' * 70}")
    print(f"  RESULTS")
    print(f"{'=' * 70}")
    actual_cost = hero_info.get("cost", 0.009)
    for mk in [mk for mk in METHOD_KEYS if mk[0] in active_methods]:
        m = METHODS[mk]
        successes = sum(1 for r in all_results if r.get(mk, {}).get("filename"))
        errors = sum(1 for r in all_results if r.get(mk, {}).get("error"))
        times = [r.get(mk, {}).get("elapsed", 0) for r in all_results
                 if r.get(mk, {}).get("filename")]
        avg = sum(times) / max(len(times), 1)
        mc = sum(r.get(mk, {}).get("cost", 0) for r in all_results)
        actual_cost += mc
        print(f"  {m['label']:35s}  {successes}/{n_shots} OK  avg {avg:.1f}s  ${mc:.3f}")
    print(f"\n  Total cost: ${actual_cost:.3f}")
    print(f"\n  HTML: {html_path}")
    print(f"  JSON: {json_path}")
    print(f"{'=' * 70}")


# Script entry point: run the comparison only when this file is executed
# directly (e.g. `python3 ab_test_comprehensive.py ...`), not when it is
# imported as a module.
if __name__ == "__main__":
    main()
