#!/usr/bin/env python3
"""
Previz Calibration Test.

Generates the same 5 shots at different resolution/step combos to find the
quality floor for previz. Uses Z-Image Turbo + LoRA (if available).

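Tests:
  A) 512x512  @ 4 steps  (absolute minimum)
  B) 512x512  @ 8 steps  (turbo standard)
  C) 512x896  @ 8 steps  (9:16 low-res, correct aspect — the adopted previz standard)
  D) 576x1024 @ 8 steps  (production resolution, turbo steps)
  E) 768x1344 @ 8 steps  (full production — baseline comparison)
  F) 512x896  @ 4 steps  (9:16 speed floor)
  G) 384x672  @ 4 steps  (9:16 below-floor resolution probe)
  H) 512x896  @ 2 steps  (9:16 step-count floor probe)
  I) 512x896  @ 3 steps  (9:16 step-count floor probe)

Saves every generated frame as a JPEG, then writes an HTML comparison page
(one row per shot, one column per config) and a calibration_results.json summary.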

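Usage:
  python3 previz_calibration_test.py <project_path> --episode 1
  python3 previz_calibration_test.py <project_path> --episode 1 --shots 1,3,11,16,19
  python3 previz_calibration_test.py <project_path> --episode 1 --configs C_512x896_8step,F_512x896_4step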

Requires: FAL_KEY environment variable
"""

import argparse
import json
import os
import sys
import time
from datetime import datetime
from pathlib import Path

try:
    import fal_client
except ImportError:
    print("ERROR: pip install fal-client")
    sys.exit(2)

try:
    from PIL import Image
    from io import BytesIO
    import requests
except ImportError:
    print("ERROR: pip install Pillow requests")
    sys.exit(2)

# Add tools dir for imports
_tools_dir = str(Path(__file__).resolve().parent)
if _tools_dir not in sys.path:
    sys.path.insert(0, _tools_dir)

from train_lora import load_registry, get_inference_config
from cost_tracker import CostTracker


# ── Previz Standard (decided 2026-02-09 via calibration testing) ──
# Resolution / step count adopted for previz; config C below runs exactly this.
PREVIZ_WIDTH = 512
PREVIZ_HEIGHT = 896
PREVIZ_STEPS = 8

# Calibration matrix as (label, width, height, steps) rows. The label doubles as
# the saved-frame filename suffix, the HTML column header and the --configs key.
_CALIBRATION_MATRIX = (
    ("A_512x512_4step", 512, 512, 4),
    ("B_512x512_8step", 512, 512, 8),
    ("C_512x896_8step", PREVIZ_WIDTH, PREVIZ_HEIGHT, PREVIZ_STEPS),
    ("D_576x1024_8step", 576, 1024, 8),
    ("E_768x1344_8step", 768, 1344, 8),
    ("F_512x896_4step", 512, 896, 4),
    ("G_384x672_4step", 384, 672, 4),
    ("H_512x896_2step", 512, 896, 2),
    ("I_512x896_3step", 512, 896, 3),
)

TEST_CONFIGS = [
    {"label": label, "width": width, "height": height, "steps": steps}
    for label, width, height, steps in _CALIBRATION_MATRIX
]


def generate_frame(prompt, width, height, steps, lora_url=None, lora_trigger=None, seed=42):
    """Generate a single frame via Z-Image Turbo (fal ``z-image/turbo/lora``).

    Args:
        prompt: Text prompt for the frame.
        width, height: Output size in pixels, sent as ``image_size``.
        steps: ``num_inference_steps`` for the request.
        lora_url: Optional LoRA weights path; attached at scale 1.0 when set.
        lora_trigger: Optional trigger word prepended to the prompt.
        seed: Fixed seed so configs are comparable across runs.

    Returns:
        ``(image_url, elapsed_ms)``. ``image_url`` is None when the response
        carries no images. ``elapsed_ms`` is the duration of the blocking
        ``fal_client.subscribe`` call.

    Raises:
        Anything raised by ``fal_client.subscribe`` propagates unchanged.
    """
    # Prepend LoRA trigger if available
    full_prompt = f"{lora_trigger} {prompt}" if lora_trigger else prompt

    payload = {
        "prompt": full_prompt,
        "image_size": {"width": width, "height": height},
        "num_inference_steps": steps,
        "seed": seed,
        "num_images": 1,
        "output_format": "jpeg",
        "enable_safety_checker": False,
    }
    if lora_url:
        payload["loras"] = [{"path": lora_url, "scale": 1.0}]

    # NOTE: the /lora endpoint is used whether or not a LoRA is attached.
    endpoint = "fal-ai/z-image/turbo/lora"

    # perf_counter is monotonic: unlike time.time() it cannot jump when the
    # system clock is adjusted mid-request, so durations are never skewed/negative.
    start = time.perf_counter()
    result = fal_client.subscribe(endpoint, arguments=payload)
    elapsed_ms = int((time.perf_counter() - start) * 1000)

    images = (result or {}).get("images") or []
    img_url = images[0].get("url") if images else None
    return img_url, elapsed_ms


def download_image(url):
    """Fetch ``url`` over HTTP and decode the response body as a PIL image.

    The request uses a 30 s connect/read timeout; any non-2xx status raises
    ``requests.HTTPError`` via ``raise_for_status()``.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    buffer = BytesIO(response.content)
    return Image.open(buffer)


def make_contact_sheet(images, labels, output_path, title=""):
    """Create a labelled contact sheet from a list of PIL Images.

    Every image is scaled to a common height and pasted left-to-right on a
    dark background. ``labels[i]`` is drawn under image ``i``; images without a
    matching label are still placed (just unlabelled). When ``title`` is
    non-empty a title band is added above the images.

    Args:
        images: PIL Images to lay out; nothing is written when empty.
        labels: Per-image caption strings (may be shorter than ``images``).
        output_path: Destination path passed to ``Image.save``.
        title: Optional heading drawn at the top of the sheet.

    Returns:
        ``output_path`` on success, or None when ``images`` is empty.
    """
    if not images:
        return None

    # Local import: only PIL.Image is imported at file level.
    from PIL import ImageDraw

    target_h = 400
    padding = 10
    label_band = 50  # room under each image for its caption
    title_band = 30 if title else 0

    # Normalize all images to same height for side-by-side
    resized = []
    for img in images:
        ratio = target_h / img.height
        new_w = max(1, int(img.width * ratio))  # extremely tall inputs must not collapse to 0 px
        resized.append(img.resize((new_w, target_h), Image.LANCZOS))

    total_w = sum(img.width for img in resized) + padding * (len(resized) + 1)
    total_h = title_band + padding + target_h + label_band

    sheet = Image.new("RGB", (total_w, total_h), (30, 30, 30))
    draw = ImageDraw.Draw(sheet)

    if title:
        draw.text((padding, 8), str(title), fill=(240, 240, 240))

    captions = list(labels) if labels else []
    img_top = title_band + padding
    x_offset = padding
    # Index loop rather than zip(): zip would silently drop images that have
    # no label even though their width was already allocated above.
    for idx, img in enumerate(resized):
        sheet.paste(img, (x_offset, img_top))
        caption = captions[idx] if idx < len(captions) else ""
        if caption:
            draw.text((x_offset, img_top + target_h + 8), str(caption), fill=(200, 200, 200))
        x_offset += img.width + padding

    sheet.save(output_path, quality=90)
    return output_path


def write_comparison_html(all_results, output_dir, episode):
    """Write an HTML comparison page showing all configs side by side.

    The page has one row per shot (sorted by shot id) and one column per
    config label, in the order labels first appear in ``all_results``. If a
    shot has no result for some config (e.g. that generation failed), a
    placeholder cell is emitted so each image stays under its own header.

    Args:
        all_results: Result dicts with ``shot_id``, ``shot_name`` and
            ``config_label``; ``local_path`` and ``duration_ms`` are optional.
        output_dir: ``Path`` the page is written into; image paths are made
            relative to it.
        episode: Episode number shown in the page title.

    Returns:
        Path of the written ``calibration_comparison.html``.
    """
    from html import escape  # storyboard text must not be injected as raw HTML

    html_path = output_dir / "calibration_comparison.html"

    # Collect unique config labels in order they appear (dict keeps insertion order)
    seen_labels = list(dict.fromkeys(r["config_label"] for r in all_results))

    # Group by shot, indexing each shot's results by config label so the row
    # can be emitted column-by-column in header order.
    shots = {}
    for r in all_results:
        shot = shots.setdefault(r["shot_id"], {"name": r["shot_name"], "configs": {}})
        shot["configs"].setdefault(r["config_label"], r)

    rows = []
    for sid in sorted(shots):
        shot = shots[sid]
        cells = []
        for label in seen_labels:
            config = shot["configs"].get(label)
            if config is None:
                cells.append(
                    '\n            <td class="missing" style="text-align:center; padding:8px; color:#666;">'
                    "&mdash;</td>"
                )
                continue
            img_path = config.get("local_path", "")
            rel_path = os.path.relpath(img_path, output_dir) if img_path else ""
            ms = config.get("duration_ms", 0)
            cells.append(f"""
            <td style="text-align:center; padding:8px;">
                <img src="{escape(rel_path)}" style="max-width:200px; max-height:350px; border:1px solid #555;">
                <br><small>{escape(str(label))}<br>{ms}ms</small>
            </td>""")

        rows.append(f"""
        <tr>
            <td style="padding:8px; font-weight:bold; vertical-align:top;">
                #{escape(str(sid))}<br><small>{escape(str(shot['name']))}</small>
            </td>
            {''.join(cells)}
        </tr>""")

    rows_html = "".join(rows)
    header_cells = "".join(f"<th>{escape(str(label))}</th>" for label in seen_labels)
    episode_text = escape(str(episode))

    page = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Previz Calibration — Episode {episode_text}</title>
    <style>
        body {{ background: #1a1a1a; color: #eee; font-family: system-ui; padding: 20px; }}
        table {{ border-collapse: collapse; }}
        th {{ padding: 10px; text-align: center; border-bottom: 2px solid #555; }}
        td {{ border-bottom: 1px solid #333; }}
        h1 {{ color: #f0f0f0; }}
        .note {{ color: #999; font-size: 13px; margin-top: 20px; }}
    </style>
</head>
<body>
    <h1>Previz Calibration — Episode {episode_text}</h1>
    <p>Generated {datetime.now().strftime('%Y-%m-%d %H:%M')} | Z-Image Turbo + LoRA</p>
    <table>
        <tr>
            <th>Shot</th>
            {header_cells}
        </tr>
        {rows_html}
    </table>
    <p class="note">
        Goal: find the quality floor where you can tell if a shot "works" compositionally
        and emotionally. Lower settings = cheaper previz. Look for: character recognizability,
        composition readability, emotional tone, framing accuracy.
    </p>
</body>
</html>"""

    # Explicit UTF-8 (matching the <meta charset>) so non-ASCII text such as the
    # em dashes or shot names does not depend on the platform's locale encoding.
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(page)

    return html_path


def _parse_shot_ids(spec):
    """Parse a comma-separated list of integer shot IDs, ignoring blank entries.

    Raises:
        ValueError: if a non-blank entry is not an integer.
    """
    return [int(part) for part in spec.split(",") if part.strip()]


def _default_shot_ids(shots):
    """Return up to five evenly spaced shot IDs (first, quartiles, last), de-duplicated."""
    n = len(shots)
    if n == 0:
        # Without this guard index n - 1 == -1 would address the (empty) list
        # from the end and raise IndexError instead of a clean error message.
        return []
    indices = [0, n // 4, n // 2, 3 * n // 4, n - 1]
    # dict.fromkeys keeps first-seen order while dropping repeats on short storyboards.
    return list(dict.fromkeys(shots[i]["id"] for i in indices))


def _resolve_lora(lora_registry, shot):
    """Look up the LoRA for the shot's first listed character.

    Returns:
        ``(lora_url, lora_trigger, char_name)``. All three are None when the
        shot lists no characters or the registry is empty. url/trigger are
        None when the lookup fails; a warning is printed and the shot is
        rendered without a LoRA.
    """
    chars = shot.get("characters_in_shot", [])
    if not chars or not lora_registry:
        return None, None, None
    char_name = chars[0]
    try:
        lora_config = get_inference_config(lora_registry, char_name)
        # Prefer z_image LoRA for turbo endpoint, fall back to flux2 t2i
        lora_url = (
            lora_config.get("z_image_t2i_path")
            or lora_config.get("z_image_base_t2i_path")
            or lora_config.get("t2i_path")
        )
        lora_trigger = lora_config.get("trigger")
    except Exception as e:
        # Best-effort: a missing/broken LoRA entry must not abort calibration,
        # but it should be visible rather than silently dropped.
        print(f"  WARNING: LoRA lookup failed for {char_name!r}: {e}")
        return None, None, char_name
    return lora_url, lora_trigger, char_name


def _log_generation(tracker, config, episode, shot_id, lora_applied, success, detail, duration_ms=None):
    """Record one calibration generation attempt in the cost tracker."""
    entry = {
        "category": "generation",
        "provider": "fal",
        "model": "z_image_turbo",
        "resolution": f"{config['width']}x{config['height']}",
        "loras": 1 if lora_applied else 0,
        "steps": config["steps"],
        "episode": episode,
        "shot_id": shot_id,
        "success": success,
        "detail": detail,
    }
    # Attempts that raised before timing was measured are logged without a duration.
    if duration_ms is not None:
        entry["duration_ms"] = duration_ms
    tracker.log(**entry)


def main():
    """Run the previz calibration matrix for one episode.

    Loads ``storyboards/storyboard_ep_NNN.json`` from the project and
    generates every selected shot under every selected config. Frames are
    saved to ``production/episode_NNN/previz_calibration/`` and each attempt
    is logged to the cost tracker. Finally it writes
    ``calibration_comparison.html`` and ``calibration_results.json``. Exits
    with status 2 on setup errors.
    """
    parser = argparse.ArgumentParser(description="Previz calibration test")
    parser.add_argument("project", help="Project path")
    parser.add_argument("--episode", type=int, required=True, help="Episode number")
    parser.add_argument(
        "--shots",
        help="Comma-separated shot IDs to test (default: 5 evenly spaced)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Seed (default: 42)")
    parser.add_argument(
        "--configs",
        help="Comma-separated config labels to run (e.g. F_512x896_4step). Default: all",
    )

    args = parser.parse_args()

    project_dir = Path(args.project)

    # Load storyboard
    sb_path = project_dir / "storyboards" / f"storyboard_ep_{args.episode:03d}.json"
    if not sb_path.exists():
        print(f"ERROR: Storyboard not found: {sb_path}")
        sys.exit(2)

    try:
        with open(sb_path, encoding="utf-8") as f:
            storyboard = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in {sb_path}: {e}")
        sys.exit(2)

    shots = storyboard.get("shots", [])

    # Select shots to test
    if args.shots:
        try:
            shot_ids = _parse_shot_ids(args.shots)
        except ValueError:
            parser.error(f"--shots must be comma-separated integers, got {args.shots!r}")
    else:
        shot_ids = _default_shot_ids(shots)

    selected_shots = [s for s in shots if s["id"] in shot_ids]
    if not selected_shots:
        print("ERROR: No matching shots found")
        sys.exit(2)

    # Load LoRA registry
    lora_registry = {}
    try:
        lora_registry = load_registry(project_dir)
    except Exception as e:
        print(f"  WARNING: LoRA registry load failed: {e}")

    # Filter configs if specified
    configs_to_run = TEST_CONFIGS
    if args.configs:
        wanted_labels = {label.strip() for label in args.configs.split(",")}
        configs_to_run = [c for c in TEST_CONFIGS if c["label"] in wanted_labels]
        if not configs_to_run:
            print(f"ERROR: No matching configs. Available: {[c['label'] for c in TEST_CONFIGS]}")
            sys.exit(2)

    # Cost tracker
    tracker = CostTracker(project_dir)

    # Output directory
    output_dir = project_dir / "production" / f"episode_{args.episode:03d}" / "previz_calibration"
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=== Previz Calibration Test ===")
    print(f"Project: {project_dir.name}")
    print(f"Episode: {args.episode}")
    print(f"Shots: {[s['id'] for s in selected_shots]}")
    print(f"Configs: {len(configs_to_run)} — {[c['label'] for c in configs_to_run]}")
    print(f"Total generations: {len(selected_shots) * len(configs_to_run)}")
    print(f"LoRA registry: {len(lora_registry)} characters loaded" if lora_registry else "LoRA registry: empty")
    print(f"Output: {output_dir}")
    print()

    # LoRA choice depends only on the shot, so resolve it once rather than per config.
    shot_plans = [(shot, *_resolve_lora(lora_registry, shot)) for shot in selected_shots]

    all_results = []

    for config in configs_to_run:
        print(f"--- Config: {config['label']} ---")

        for shot, lora_url, lora_trigger, lora_char in shot_plans:
            # Use the first_frame prompt; fall back to subject when it is missing OR empty.
            prompt = shot.get("first_frame") or shot.get("subject") or ""
            if not prompt:
                print(f"  Shot #{shot['id']}: no prompt, skipping")
                continue

            # .get() so a shot without a name cannot crash the run outside the try below.
            shot_name = shot.get("name", "")
            lora_tag = f"+LoRA({lora_char})" if lora_url else "no-LoRA"
            print(f"  Shot #{shot['id']} '{shot_name}' [{lora_tag}]... ", end="", flush=True)

            lora_applied = lora_url is not None
            try:
                img_url, elapsed_ms = generate_frame(
                    prompt,
                    config["width"],
                    config["height"],
                    config["steps"],
                    lora_url=lora_url,
                    lora_trigger=lora_trigger,
                    seed=args.seed,
                )

                if not img_url:
                    _log_generation(
                        tracker, config, args.episode, shot["id"], lora_applied,
                        success=False,
                        detail=f"previz calibration {config['label']}: no image returned",
                        duration_ms=elapsed_ms,
                    )
                    print("FAILED (no image returned)")
                    continue

                # Download and save locally
                img = download_image(img_url)
                filename = f"shot_{shot['id']:02d}_{config['label']}.jpg"
                local_path = output_dir / filename
                img.save(str(local_path), quality=90)

                all_results.append({
                    "shot_id": shot["id"],
                    "shot_name": shot_name,
                    "config_label": config["label"],
                    "width": config["width"],
                    "height": config["height"],
                    "steps": config["steps"],
                    "duration_ms": elapsed_ms,
                    "local_path": str(local_path),
                    "lora_applied": lora_applied,
                })
                _log_generation(
                    tracker, config, args.episode, shot["id"], lora_applied,
                    success=True,
                    detail=f"previz calibration {config['label']}",
                    duration_ms=elapsed_ms,
                )
                print(f"OK ({elapsed_ms}ms)")

            except Exception as e:
                # One failed generation must not abort the remaining (paid) calls.
                _log_generation(
                    tracker, config, args.episode, shot["id"], lora_applied,
                    success=False,
                    detail=f"previz calibration {config['label']}: {e}",
                )
                print(f"ERROR: {e}")

    # Write comparison HTML
    if all_results:
        html_path = write_comparison_html(all_results, output_dir, args.episode)
        print()
        print("=== Results ===")
        print(f"  {len(all_results)} frames generated")
        print(f"  Comparison: {html_path}")
        # as_uri() needs an absolute path; a relative project path would give a broken link.
        print(f"  Open in browser: {html_path.resolve().as_uri()}")

        # Write results JSON
        results_json = output_dir / "calibration_results.json"
        with open(results_json, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "episode": args.episode,
                    "shots_tested": [s["id"] for s in selected_shots],
                    # Record only the configs actually run (honours --configs).
                    "configs": configs_to_run,
                    "results": all_results,
                    "generated_at": datetime.now().isoformat(),
                },
                f,
                indent=2,
            )
        print(f"  Results JSON: {results_json}")
    else:
        print("No frames generated.")


if __name__ == "__main__":
    main()
