#!/usr/bin/env python3
"""
Prompt A/B Test Harness — Kling Prompt Strategy Validation

Tests prompt engineering hypotheses before committing to full system build.
Separates cheap validation (prompt generation) from expensive validation
(Kling generation) so you only spend money on the most promising variants.

Usage:
    # Phase 1: Generate and compare prompts (costs ~$0.50)
    python3 tools/prompt_ab_test.py --project tartarus --phase prompts

    # Phase 1 with specific shots
    python3 tools/prompt_ab_test.py --project tartarus --shots EP001_SH32,EP001_SH34 --phase prompts

    # Phase 3: Generate selected variants via StepRunner (costs $$$)
    python3 tools/prompt_ab_test.py --project tartarus --shots EP001_SH34 --variants D,F --phase generate

    # Phase 3 with model comparison (V3 vs O3)
    python3 tools/prompt_ab_test.py --project tartarus --shots EP001_SH34 --variants D --models kling-v3,kling-o3 --phase generate

    # Phase 3 with negative prompt test
    python3 tools/prompt_ab_test.py --project tartarus --shots EP001_SH34 --variants D --test-negative --phase generate

    # Phase 3 with audio toggle test
    python3 tools/prompt_ab_test.py --project tartarus --shots EP001_SH34 --variants D --test-audio --phase generate

    # Phase 3 batch test (2-shot multi_prompt vs individual)
    python3 tools/prompt_ab_test.py --project tartarus --shots EP001_SH32,EP001_SH34 --variants D --test-batch --phase generate
"""

import argparse
import json
import os
import re
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from recoil.core.paths import projects_root, ProjectPaths
from recoil.pipeline._lib.verb_calibration import calibrate_verbs

try:
    from recoil.pipeline._lib.prompt_engine import build_kling_i2v_prompt as _production_prompt_builder
except ImportError:
    _production_prompt_builder = None

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Negative prompt for the "with_negative" arm of run_negative_prompt_test().
# Lists motion/anatomy artefacts (morphing, extra limbs, frozen motion,
# spinning backgrounds, ...) that the test checks Kling can be steered away from.
NEGATIVE_PROMPT_ACTION = (
    "morphing body parts, extra limbs, body merging, frozen motion, "
    "static pose, background spinning, distorted limbs, extra fingers, "
    "duplicate body parts, face distortion, watermark"
)

# Default test shots — covers range from env-only to 2-character heavy action
# (also the default value of the --shots CLI flag).
DEFAULT_TEST_SHOTS = ["EP001_SH32", "EP001_SH34", "EP001_SH36"]

# The I2V system prompt (from Opus R3, consultation winner).
# Same prompt used for all enrichment models to isolate model capability.
# Passed as the system instruction for variants D (Flash), E (Sonnet) and
# F (Opus) in generate_variants(); editing it changes all three at once, so
# results from runs made before an edit are not directly comparable.
I2V_SYSTEM_PROMPT = """You are a video prompt engineer for Kling's Image-to-Video model.

INPUT: A narrative shot description and a reference image.
OUTPUT: A motion-only prompt of 30 words or fewer.

RULES — follow every one without exception:

1. WORD CAP: 30 words maximum. Count every word. If over 30, cut environment and lighting first, then adjectives. Never cut the primary action verb or camera direction.

2. CAMERA FIRST: Begin with camera movement if specified (e.g., "Camera pushes in slowly," "Static shot," "Slow dolly left"). If no camera movement, begin with the character's action.

3. ONE ACTION ONLY: Describe exactly one continuous physical action. If the narrative describes multiple actions, keep only the most visually important one. Never chain actions with "and" or commas.

4. TEMPORAL MARKERS: Use "Initially... then..." ONLY when the single action has a clear two-phase motion (e.g., "Initially reaches forward, then pulls back"). Do not force temporal markers onto simple actions.

5. MOTION ENDPOINTS: Every action must imply a start state and an end state. Bad: "Wren moves." Good: "Wren's arm extends from her side to grip the railing." The model needs to know WHERE motion begins and WHERE it resolves.

6. VERB CALIBRATION: Use restrained, physics-grounded verbs. No: "explodes," "slams," "whips," "rockets." Yes: "steps," "lifts," "turns," "reaches," "lowers," "shifts." The input will already have verbs calibrated — preserve them exactly.

7. NO VISUAL DESCRIPTIONS: Do not describe what the character looks like, what they are wearing, hair color, eye color, skin tone, or any static visual attribute. The reference image carries all identity. Do not describe environment unless it is physically interacting with the action (e.g., "grips the metal railing").

8. CHARACTER NAMES: Use character names ("Wren," "Torch"), never pronouns ("she," "he," "they").

9. NO BRACKET NOTATION: Do not use [brackets], {braces}, or any structured markup. Write plain natural English.

10. NO STYLE OR QUALITY LANGUAGE: Do not append "cinematic," "photorealistic," "high quality," "4K," or similar.

EXAMPLES:

Input: "Wren is terrified and sprints down the corridor, sparks flying from damaged panels on the walls, emergency red lighting pulses"
Output: "Camera tracks laterally. Wren runs steadily down the corridor, arms pumping, moving from background toward camera."

Input: "Close-up of Torch's hands carefully gripping Wren's shoulder plate, pulling her upward through the hatch"
Output: "Static close-up. Torch's hands tighten on the shoulder plate, then lift upward steadily from waist height to chest height."

Input: "Wren turns to look at something off-screen, fear on her face"
Output: "Camera holds static. Wren's head turns slowly from facing forward to looking frame-left, shoulders tensing."
"""


# ---------------------------------------------------------------------------
# Plan data loading
# ---------------------------------------------------------------------------

def load_plan_shots(project: str, shot_ids: list[str]) -> dict[str, dict]:
    """Load shot data from the project's plan JSON."""
    plan_dir = ProjectPaths.for_project(project).plans_dir

    shots = {}
    for plan_file in sorted(plan_dir.glob("*_plan.json")):
        with open(plan_file) as f:
            plan = json.load(f)
        for shot in plan.get("shots", []):
            if shot["shot_id"] in shot_ids:
                shots[shot["shot_id"]] = shot
        if len(shots) == len(shot_ids):
            break

    missing = set(shot_ids) - set(shots.keys())
    if missing:
        print(f"WARNING: Shots not found in plan data: {missing}")

    return shots


# ---------------------------------------------------------------------------
# Prompt building (mirrors prompt_engine.py but isolated for testing)
# ---------------------------------------------------------------------------

def build_baseline_prompt(shot: dict, word_cap: int = 40) -> str:
    """Build the baseline (un-enriched) I2V prompt for a shot.

    Mirrors the production prompt_engine.py logic exactly so that variants
    B/C differ from production only in what they deliberately change.

    The prompt is a sequence of short sentences:
      1. shot-type label, plus the camera movement when it is not static;
      2. the skeleton's action line;
      3. the kinetic action;
      4. the shot's top-level director notes (ignored when blank).

    Args:
        shot: Shot dict from plan data.
        word_cap: Maximum number of whitespace-separated words. When the
            prompt is longer it is cut to ``word_cap`` words and then trimmed
            back to the last full stop so it ends on a sentence boundary.

    Returns:
        The prompt string.
    """
    prompt_data = shot.get("prompt_data", {})
    skeleton = prompt_data.get("prompt_skeleton", {})
    shot_type = prompt_data.get("shot_type", "MS")
    camera_movement = prompt_data.get("camera_movement", "static")

    # Same label tables as prompt_engine.py (type_names / movement_names).
    type_names = {
        "ECU": "Extreme close-up", "CU": "Close-up", "BCU": "Big close-up",
        "MCU": "Medium close-up", "MS": "Medium shot", "LS": "Long shot",
        "WS": "Wide shot", "FS": "Full shot", "INSERT": "Insert",
    }
    movement_names = {
        "pan": "panning", "tilt": "tilting", "push_in": "pushing in",
        "pull_back": "pulling back", "tracking": "tracking",
        "crane": "crane", "handheld": "handheld", "steadicam": "Steadicam",
    }
    label = type_names.get(shot_type, shot_type)

    def as_sentence(text: str) -> str:
        # Normalise to exactly one trailing full stop.
        return text.strip().rstrip(".") + "."

    if not camera_movement or camera_movement == "static":
        parts = [f"{label}."]
    else:
        parts = [f"{label}, {movement_names.get(camera_movement, camera_movement)}."]

    action = skeleton.get("action_line", "")
    if action:
        parts.append(as_sentence(action))

    kinetic = prompt_data.get("kinetic_action", "")
    if kinetic:
        parts.append(as_sentence(kinetic))

    # Production reads director notes from the shot's top level, NOT prompt_data.
    notes = shot.get("director_notes", "")
    if notes and notes.strip():
        parts.append(as_sentence(notes))

    tokens = " ".join(p for p in parts if p).split()
    if len(tokens) <= word_cap:
        return " ".join(tokens)

    # Over the cap: hard cut, then back off to the last sentence boundary.
    clipped = " ".join(tokens[:word_cap])
    last_stop = clipped.rfind(".")
    if last_stop != -1:
        return clipped[:last_stop + 1]
    return clipped.rstrip(".,") + "."


def build_enrichment_input(shot: dict) -> str:
    """Build the raw, newline-separated input text for the enrichment models.

    Emits one ``Label: value`` line per populated field, always starting with
    the shot type (default ``MS``). Fields that are missing, empty or ``None``
    are omitted. Director notes come from the shot's top level (not from
    ``prompt_data``) and, as in build_baseline_prompt(), are skipped when they
    contain only whitespace.

    Args:
        shot: Shot dict from plan data. ``prompt_data`` / ``audio_data`` may
            be absent or ``null``.

    Returns:
        The enrichment input text.
    """
    # ``or {}`` also covers keys that are present but null in the plan JSON.
    prompt_data = shot.get("prompt_data") or {}
    skeleton = prompt_data.get("prompt_skeleton") or {}
    audio = shot.get("audio_data") or {}
    director_notes = shot.get("director_notes") or ""

    optional_fields = [
        ("Camera", prompt_data.get("camera_movement")),
        ("Subject", skeleton.get("subject_line")),
        ("Action", skeleton.get("action_line")),
        ("Kinetic", prompt_data.get("kinetic_action")),
        ("Emotion", skeleton.get("emotion_line")),
        ("Director notes", director_notes if director_notes.strip() else None),
        ("Foley", audio.get("foley_action")),
    ]

    lines = [f"Shot type: {prompt_data.get('shot_type', 'MS')}"]
    lines.extend(f"{label}: {value}" for label, value in optional_fields if value)
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Enrichment model calls
# ---------------------------------------------------------------------------

def enrich_with_flash(raw_input: str, system_prompt: str) -> tuple[str, float]:
    """Call Gemini 2.5 Flash for enrichment.

    Args:
        raw_input: Enrichment input text (see build_enrichment_input()).
        system_prompt: System instruction for the model.

    Returns:
        ``(enriched_text, latency_ms)``. Problems are reported in-band as
        bracketed placeholders (``[SKIP: ...]`` / ``[ERROR: ...]``) so the
        rest of the harness can filter them. This includes an empty
        response, which would otherwise be submitted as an empty prompt.
        Latency is 0.0 for skips and exceptions.
    """
    try:
        from google import genai
    except ImportError:
        return "[SKIP: google-genai not installed]", 0.0

    api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        return "[SKIP: no GOOGLE_API_KEY or GEMINI_API_KEY]", 0.0

    try:
        client = genai.Client(api_key=api_key)
        # perf_counter is monotonic; time.time() can jump with clock changes.
        start = time.perf_counter()
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=raw_input,
            config=genai.types.GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.3,
            ),
        )
        latency = (time.perf_counter() - start) * 1000
        # ``response.text`` is None when no text part was returned.
        text = (response.text or "").strip()
    except Exception as e:  # best-effort: one failing provider must not abort the run
        return f"[ERROR: {e}]", 0.0

    if not text:
        return "[ERROR: empty response from gemini-2.5-flash]", latency
    return text, latency


def enrich_with_anthropic(raw_input: str, system_prompt: str, model: str) -> tuple[str, float]:
    """Call an Anthropic model for enrichment.

    Args:
        raw_input: Enrichment input text (see build_enrichment_input()).
        system_prompt: System prompt for the model.
        model: Anthropic model ID (e.g. ``"claude-sonnet-4-6"``).

    Returns:
        ``(enriched_text, latency_ms)``. Problems are reported in-band as
        bracketed placeholders (``[SKIP: ...]`` / ``[ERROR: ...]``) so the
        rest of the harness can filter them. This includes a response with
        no text, which would otherwise be submitted as an empty prompt.
        Latency is 0.0 for skips and exceptions.
    """
    try:
        import anthropic
    except ImportError:
        return "[SKIP: anthropic SDK not installed]", 0.0

    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return "[SKIP: no ANTHROPIC_API_KEY]", 0.0

    try:
        client = anthropic.Anthropic(api_key=api_key)
        # perf_counter is monotonic; time.time() can jump with clock changes.
        start = time.perf_counter()
        response = client.messages.create(
            model=model,
            max_tokens=200,
            temperature=0.3,
            system=system_prompt,
            messages=[{"role": "user", "content": raw_input}],
        )
        latency = (time.perf_counter() - start) * 1000
        # Collect text blocks only; do not assume content[0] exists or is text.
        text = "".join(
            block.text
            for block in (response.content or [])
            if getattr(block, "type", None) == "text"
        ).strip()
    except Exception as e:  # best-effort: one failing provider must not abort the run
        return f"[ERROR: {e}]", 0.0

    if not text:
        return f"[ERROR: empty response from {model}]", latency
    return text, latency


# ---------------------------------------------------------------------------
# Variant generation
# ---------------------------------------------------------------------------

# Human-readable description of each prompt variant, keyed by variant ID.
# generate_variants() produces one result per key, and the labels are shown in
# the comparison report (format_comparison_table()).
VARIANT_LABELS = {
    "A": "Baseline (40w, no verb cal, no enrichment)",
    "B": "Verb calibration only (40w)",
    "C": "Verb cal + 30-word cap",
    "D": "Verb cal + 30w + Flash enrichment",
    "E": "Verb cal + 30w + Sonnet enrichment",
    "F": "Verb cal + 30w + Opus enrichment",
}


def generate_variants(shot: dict) -> dict[str, dict]:
    """Generate all 6 prompt variants for a shot.

    Returns dict of {variant_id: {prompt, word_count, enrichment_model,
    enrichment_latency_ms, has_temporal_markers, has_motion_endpoints}}.
    """
    shot_id = shot["shot_id"]
    results = {}

    # --- A: Baseline (production prompt builder) ---
    if _production_prompt_builder:
        prompt_a = _production_prompt_builder(shot)
    else:
        prompt_a = build_baseline_prompt(shot, word_cap=40)
    results["A"] = _analyze_prompt(prompt_a, model=None, latency=0.0)

    # --- B: Verb calibration only, 40-word cap ---
    prompt_b = calibrate_verbs(build_baseline_prompt(shot, word_cap=40))
    results["B"] = _analyze_prompt(prompt_b, model=None, latency=0.0)

    # --- C: Verb cal + 30-word cap ---
    prompt_c = calibrate_verbs(build_baseline_prompt(shot, word_cap=30))
    results["C"] = _analyze_prompt(prompt_c, model=None, latency=0.0)

    # --- D, E, F: Enrichment variants ---
    # Build the enrichment input from plan data, apply verb cal first
    raw_input = calibrate_verbs(build_enrichment_input(shot))

    # D: Flash
    enriched_d, latency_d = enrich_with_flash(raw_input, I2V_SYSTEM_PROMPT)
    results["D"] = _analyze_prompt(enriched_d, model="gemini-2.5-flash", latency=latency_d)

    # E: Sonnet
    enriched_e, latency_e = enrich_with_anthropic(raw_input, I2V_SYSTEM_PROMPT, "claude-sonnet-4-6")
    results["E"] = _analyze_prompt(enriched_e, model="claude-sonnet-4-6", latency=latency_e)

    # F: Opus
    enriched_f, latency_f = enrich_with_anthropic(raw_input, I2V_SYSTEM_PROMPT, "claude-opus-4-6")
    results["F"] = _analyze_prompt(enriched_f, model="claude-opus-4-6", latency=latency_f)

    return results


def _analyze_prompt(prompt: str, model: str | None, latency: float) -> dict:
    """Analyze a prompt for metrics."""
    words = prompt.split() if not prompt.startswith("[") else []
    return {
        "prompt": prompt,
        "word_count": len(words),
        "enrichment_model": model,
        "enrichment_latency_ms": round(latency, 1),
        "has_temporal_markers": any(m in prompt.lower() for m in ["initially", "then "]),
        "has_motion_endpoints": any(m in prompt.lower() for m in [
            "until", "reaching", "to the", "onto the", "from ", "toward",
            "level with", "at chest height", "at shoulder", "to grip",
        ]),
    }


# ---------------------------------------------------------------------------
# Output formatting
# ---------------------------------------------------------------------------

def format_comparison_table(all_results: dict[str, dict[str, dict]]) -> str:
    """Render all variant results as a markdown comparison report.

    The report starts with a UTC generation timestamp. For each shot it
    contains a summary table (words, enrichment model, latency, and the
    temporal/endpoint heuristics), then every variant's full prompt in a
    fenced block, then a horizontal rule.

    Args:
        all_results: ``{shot_id: {variant_id: metrics}}`` as built by
            generate_variants().

    Returns:
        The markdown text.
    """
    def yes_no(flag) -> str:
        return "Yes" if flag else "No"

    out = [
        "# Prompt A/B Test — Comparison Report",
        "",
        f"Generated: {datetime.now(timezone.utc).isoformat()}",
        "",
    ]

    for shot_id, variants in all_results.items():
        out += [
            f"## {shot_id}",
            "",
            "| Variant | Words | Model | Latency | Temporal? | Endpoints? |",
            "|---------|-------|-------|---------|-----------|------------|",
        ]
        for vid, info in variants.items():
            latency_ms = info["enrichment_latency_ms"]
            # Zero latency means "not enriched / not measured" -> show a dash.
            latency_cell = f"{latency_ms}ms" if latency_ms else "-"
            out.append(
                f"| **{vid}** ({VARIANT_LABELS.get(vid, vid)}) | {info['word_count']} | "
                f"{info['enrichment_model'] or 'none'} | {latency_cell} | "
                f"{yes_no(info['has_temporal_markers'])} | {yes_no(info['has_motion_endpoints'])} |"
            )
        out.append("")

        for vid, info in variants.items():
            out += [
                f"### Variant {vid}: {VARIANT_LABELS.get(vid, vid)}",
                "```",
                info["prompt"],
                "```",
                "",
            ]

        out += ["---", ""]

    return "\n".join(out)


# ---------------------------------------------------------------------------
# Output saving
# ---------------------------------------------------------------------------

def save_results(project: str, all_results: dict, output_dir: Path | None = None) -> Path:
    """Write the Phase 1 artefacts for one run and return the run directory.

    Files written under ``output_dir``:
      - ``comparison.md``: side-by-side report (format_comparison_table())
      - ``prompts/<shot_id>_<variant>.txt``: one file per prompt
      - ``enrichment_metadata.json``: all metrics except the prompt text
      - ``scoring.md``: blank manual scoring sheet

    Args:
        project: Project name; used for the default output location.
        all_results: ``{shot_id: {variant_id: metrics}}`` from generate_variants().
        output_dir: Destination directory (``Path`` or ``str``). Defaults to
            ``<projects_root>/<project>/tests/prompt-ab/run_<local timestamp>``.

    Returns:
        The directory the files were written to.
    """
    if output_dir is None:
        timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        output_dir = projects_root() / project / "tests" / "prompt-ab" / f"run_{timestamp}"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8: the reports contain non-ASCII text (em dashes, model
    # output), and the platform default encoding is not UTF-8 everywhere.
    (output_dir / "comparison.md").write_text(
        format_comparison_table(all_results), encoding="utf-8"
    )

    prompts_dir = output_dir / "prompts"
    prompts_dir.mkdir(exist_ok=True)
    for shot_id, variants in all_results.items():
        for vid, info in variants.items():
            (prompts_dir / f"{shot_id}_{vid}.txt").write_text(info["prompt"], encoding="utf-8")

    metadata = {
        shot_id: {
            vid: {key: value for key, value in info.items() if key != "prompt"}
            for vid, info in variants.items()
        }
        for shot_id, variants in all_results.items()
    }
    (output_dir / "enrichment_metadata.json").write_text(
        json.dumps(metadata, indent=2), encoding="utf-8"
    )

    (output_dir / "scoring.md").write_text(
        _build_scoring_template(all_results), encoding="utf-8"
    )

    return output_dir


def _build_scoring_template(all_results: dict) -> str:
    """Generate a manual scoring template for Phase 4."""
    lines = [
        "# Prompt A/B Test — Scoring Sheet",
        "",
        "Score each generated video 1-5 on these criteria.",
        "Leave blank for variants not sent to Kling.",
        "",
    ]
    for shot_id in all_results:
        lines.append(f"## {shot_id}")
        lines.append("")
        lines.append("| Variant | Motion Fidelity | Identity | Geometry | Background | Editorial |")
        lines.append("|---------|----------------|----------|----------|------------|-----------|")
        for vid in all_results[shot_id]:
            lines.append(f"| {vid} | | | | | |")
        lines.append("")
        lines.append("**Notes:**")
        lines.append("")
        lines.append("---")
        lines.append("")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Phase 3: Generation via StepRunner
# ---------------------------------------------------------------------------

def submit_for_generation(
    project: str,
    shot_id: str,
    prompt: str,
    model: str = "kling-v3",
    negative_prompt: str | None = None,
    generate_audio: bool = False,
    elements: str | None = None,
    duration: int | None = None,
    dry_run: bool = False,
) -> dict:
    """Submit a single prompt variant to StepRunner for generation.

    Returns metadata dict with cost, output path, etc.
    """
    if dry_run:
        print(f"  [DRY RUN] Would generate {shot_id} with model={model}")
        print(f"  Prompt: {prompt[:80]}...")
        if negative_prompt:
            print(f"  Negative: {negative_prompt[:60]}...")
        print(f"  Audio: {generate_audio}, Duration: {duration}, Elements: {elements or 'none'}")
        return {"dry_run": True, "shot_id": shot_id, "prompt": prompt}

    # Build the command for dispatch_cli.py
    cmd_parts = [
        sys.executable, str(PROJECT_ROOT / "tools" / "dispatch_cli.py"),
        "--project", project,
        "--shot", shot_id,
        "--model", model,
        "--prompt", prompt,
    ]

    if duration:
        cmd_parts.extend(["--duration", str(duration)])
    if elements:
        cmd_parts.extend(["--elements", elements])
    if negative_prompt:
        cmd_parts.extend(["--negative-prompt", negative_prompt])
    if generate_audio:
        cmd_parts.append("--generate-audio")

    print(f"\n  Submitting {shot_id} → {model}")
    print(f"  Prompt ({len(prompt.split())}w): {prompt}")
    if negative_prompt:
        print(f"  Negative prompt: {negative_prompt[:60]}...")
    print(f"  Audio: {generate_audio}")

    import subprocess
    try:
        result = subprocess.run(cmd_parts, capture_output=True, text=True, cwd=str(PROJECT_ROOT), timeout=420)
    except subprocess.TimeoutExpired:
        print(f"  TIMEOUT: subprocess exceeded 420s for {shot_id}")
        return {"error": "subprocess timeout (420s)", "shot_id": shot_id}

    if result.returncode != 0:
        print(f"  ERROR: {result.stderr[:200]}")
        return {"error": result.stderr, "shot_id": shot_id}

    print(f"  SUCCESS: {result.stdout[-200:]}")
    return {"success": True, "shot_id": shot_id, "stdout": result.stdout}


# ---------------------------------------------------------------------------
# Phase 3 extensions: model comparison, negative prompt, audio, batch
# ---------------------------------------------------------------------------

def run_model_comparison(project: str, shot_id: str, prompt: str, models: list[str], elements: str | None = None, dry_run: bool = False):
    """Run the same prompt on multiple models (e.g., V3 vs O3)."""
    print(f"\n=== Model Comparison: {shot_id} ===")
    results = {}
    for model in models:
        print(f"\n--- Model: {model} ---")
        result = submit_for_generation(
            project, shot_id, prompt, model=model,
            elements=elements,
            dry_run=dry_run,
        )
        results[model] = result
    return results


def run_negative_prompt_test(project: str, shot_id: str, prompt: str, model: str = "kling-v3", dry_run: bool = False):
    """Submit the same prompt twice: without and then with NEGATIVE_PROMPT_ACTION.

    Returns:
        ``{"without_negative": result, "with_negative": result}``.
    """
    print(f"\n=== Negative Prompt Test: {shot_id} ===")
    arms = (
        ("without_negative", None),
        ("with_negative", NEGATIVE_PROMPT_ACTION),
    )
    outcomes = {}
    for arm_name, negative in arms:
        print(f"\n--- {arm_name} ---")
        outcomes[arm_name] = submit_for_generation(
            project,
            shot_id,
            prompt,
            model=model,
            negative_prompt=negative,
            dry_run=dry_run,
        )
    return outcomes


def run_audio_toggle_test(project: str, shot_id: str, prompt: str, model: str = "kling-v3", dry_run: bool = False):
    """Submit the same prompt twice: with audio generation on, then off.

    Returns:
        ``{"audio_on": result, "audio_off": result}``.
    """
    print(f"\n=== Audio Toggle Test: {shot_id} ===")
    outcomes = {}
    for arm_name, with_audio in (("audio_on", True), ("audio_off", False)):
        print(f"\n--- {arm_name} ---")
        outcomes[arm_name] = submit_for_generation(
            project,
            shot_id,
            prompt,
            model=model,
            generate_audio=with_audio,
            dry_run=dry_run,
        )
    return outcomes


def run_batch_test(project: str, shot_ids: list[str], prompts: dict[str, str], model: str = "kling-v3", dry_run: bool = False):
    """Compare single-shot I2V vs multi_prompt batch for same shots."""
    print(f"\n=== Batch vs Single Test: {', '.join(shot_ids)} ===")
    if len(shot_ids) < 2:
        print("  Need at least 2 shots for batch test.")
        return {}

    results = {"single": {}, "batch": {}}

    # Single shots
    print("\n--- Individual I2V ---")
    for sid in shot_ids:
        if sid in prompts:
            result = submit_for_generation(project, sid, prompts[sid], model=model, dry_run=dry_run)
            results["single"][sid] = result

    # Batch (multi_prompt)
    print(f"\n--- Multi-prompt batch ({len(shot_ids)} shots) ---")
    if dry_run:
        print(f"  [DRY RUN] Would submit {len(shot_ids)}-shot batch")
        for i, sid in enumerate(shot_ids):
            print(f"    Shot {i+1}: {prompts.get(sid, '???')[:60]}...")
    else:
        cmd_parts = [
            sys.executable, str(PROJECT_ROOT / "tools" / "dispatch_cli.py"),
            "--project", project,
            "--shots", ",".join(shot_ids),
            "--model", model,
        ]
        import subprocess
        result = subprocess.run(cmd_parts, capture_output=True, text=True, cwd=str(PROJECT_ROOT))
        results["batch"] = {"stdout": result.stdout, "returncode": result.returncode}

    return results


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _find_hero_frame(project: str, shot_id: str) -> Path | None:
    """Return the most recently modified hero frame for ``shot_id``, or None.

    I2V requires a start image, so shots without one are skipped before any
    submission. Looks in ``output/frames/ep_NNN`` and then
    ``output/previs/ep_NNN`` under the project directory for files matching
    ``*shot_NNN*``. Returns the newest match from the first directory that has
    any. Uses the same glob pattern as find_hero_frame() in dispatch_cli.py;
    keep the two in sync.
    """
    ep_match = re.match(r"EP(\d+)", shot_id)
    ep_num = int(ep_match.group(1)) if ep_match else 1
    sh_match = re.search(r"SH(\d+)", shot_id)
    # NOTE(review): without an SHnn part the pattern degrades to "shot_" and
    # matches any shot's frame; kept for parity with dispatch_cli.py.
    shot_glob = f"shot_{int(sh_match.group(1)):03d}" if sh_match else "shot_"

    output_root = projects_root() / project / "output"
    for search_dir in (
        output_root / "frames" / f"ep_{ep_num:03d}",
        output_root / "previs" / f"ep_{ep_num:03d}",
    ):
        if not search_dir.exists():
            continue
        newest = max(
            search_dir.glob(f"*{shot_glob}*"),
            key=lambda p: p.stat().st_mtime,
            default=None,
        )
        if newest is not None:
            return newest
    return None


def _usable_prompt(variant: dict | None) -> str | None:
    """Return the variant's prompt if it may be submitted to Kling, else None.

    Enrichment helpers report problems as ``[SKIP: ...]`` / ``[ERROR: ...]``
    placeholders. Those, missing variants and blank prompts are all rejected.
    """
    prompt = (variant or {}).get("prompt") or ""
    if not prompt.strip() or prompt.startswith("["):
        return None
    return prompt


def _run_extension_tests(
    args: argparse.Namespace,
    shots: dict[str, dict],
    all_results: dict[str, dict[str, dict]],
    shot_ids: list[str],
    variant_ids: list[str],
    hero_ready: list[str],
) -> None:
    """Run the optional Phase 3 experiments selected on the command line.

    Covers --models, --test-negative, --test-audio and --test-batch. Every
    experiment uses the prompt of the first requested variant. Only shots that
    have a hero frame and a usable prompt for that variant take part.
    """
    if not (args.models or args.test_negative or args.test_audio or args.test_batch):
        return

    primary = variant_ids[0] if variant_ids else None
    ready = set(hero_ready)
    # Ordered by the CLI --shots order (matters for the batch sequence).
    prompts: dict[str, str] = {}
    for sid in shot_ids:
        if sid not in ready or sid in prompts:
            continue
        prompt = _usable_prompt(all_results.get(sid, {}).get(primary))
        if prompt is not None:
            prompts[sid] = prompt

    if not prompts:
        print(f"WARNING: No usable prompt for variant {primary!r} on any "
              f"hero-ready shot; skipping extension tests.")
        return

    if args.models:
        models = [m.strip() for m in args.models.split(",") if m.strip()]
        for sid, prompt in prompts.items():
            asset_chars = (shots[sid].get("asset_data") or {}).get("characters") or []
            char_ids = [c["char_id"] for c in asset_chars if isinstance(c, dict) and "char_id" in c]
            run_model_comparison(
                args.project, sid, prompt,
                models=models, elements=",".join(char_ids) or None, dry_run=args.dry_run,
            )

    if args.test_negative:
        for sid, prompt in prompts.items():
            run_negative_prompt_test(args.project, sid, prompt, dry_run=args.dry_run)

    if args.test_audio:
        for sid, prompt in prompts.items():
            run_audio_toggle_test(args.project, sid, prompt, dry_run=args.dry_run)

    if args.test_batch:
        # Use shots that have a prompt so both arms cover the same shots.
        run_batch_test(args.project, list(prompts)[:2], prompts, dry_run=args.dry_run)


def main() -> None:
    """CLI entry point: parse arguments, load shots, then run Phase 1 and/or Phase 3.

    Phase 1 ("prompts") builds all prompt variants and saves a comparison
    report. Phase 3 ("generate") submits the selected variants for shots
    that have a hero frame, then runs any requested extension tests. With
    ``--phase all`` Phase 3 reuses Phase 1's variants, so the submitted
    prompts are exactly the ones saved for review.
    """
    parser = argparse.ArgumentParser(description="Prompt A/B Test Harness")
    parser.add_argument("--project", required=True, help="Project name (e.g., tartarus)")
    parser.add_argument("--shots", default=",".join(DEFAULT_TEST_SHOTS),
                        help="Comma-separated shot IDs")
    parser.add_argument("--phase", choices=["prompts", "generate", "all"], default="prompts",
                        help="Which phase to run")
    parser.add_argument("--variants", default="A,B,C,D,E,F",
                        help="Which variants to generate/submit (e.g., D,F)")
    parser.add_argument("--models", default=None,
                        help="Comma-separated models for comparison (e.g., kling-v3,kling-o3)")
    parser.add_argument("--test-negative", action="store_true",
                        help="Run negative prompt A/B test")
    parser.add_argument("--test-audio", action="store_true",
                        help="Run generate_audio on/off test")
    parser.add_argument("--test-batch", action="store_true",
                        help="Run single vs batch comparison")
    parser.add_argument("--dry-run", action="store_true",
                        help="Print Kling submissions instead of running them "
                             "(enrichment API calls used to build prompts still run)")
    args = parser.parse_args()

    # Drop empty entries so a trailing comma ("D,F,") does not yield a "" ID.
    shot_ids = [s.strip() for s in args.shots.split(",") if s.strip()]
    variant_ids = [v.strip() for v in args.variants.split(",") if v.strip()]

    print(f"Project: {args.project}")
    print(f"Shots: {shot_ids}")
    print(f"Phase: {args.phase}")
    print()

    unknown_variants = [v for v in variant_ids if v not in VARIANT_LABELS]
    if unknown_variants:
        print(f"WARNING: Unknown variant IDs ignored: {unknown_variants} "
              f"(valid: {', '.join(VARIANT_LABELS)})")

    shots = load_plan_shots(args.project, shot_ids)
    if not shots:
        print("ERROR: No shots found. Check project name and shot IDs.")
        sys.exit(1)

    print(f"Loaded {len(shots)} shots: {list(shots.keys())}")
    print()

    # Variants per shot: filled by Phase 1 and reused by Phase 3, so that
    # --phase all does not pay for (and diverge from) a second enrichment run.
    all_results: dict[str, dict[str, dict]] = {}

    # -----------------------------------------------------------------------
    # Phase 1: Prompt comparison
    # -----------------------------------------------------------------------
    if args.phase in ("prompts", "all"):
        print("=" * 60)
        print("PHASE 1: Generating prompt variants")
        print("=" * 60)
        print()

        for shot_id, shot in shots.items():
            print(f"--- {shot_id} ---")
            print(f"  Source: {(shot.get('source_text') or 'N/A')[:80]}...")
            variants = generate_variants(shot)
            all_results[shot_id] = variants

            for vid, v in variants.items():
                status = "OK" if not v["prompt"].startswith("[") else v["prompt"]
                print(f"  {vid}: {v['word_count']}w, {status[:60]}")
            print()

        output_dir = save_results(args.project, all_results)
        print(f"\nResults saved to: {output_dir}")
        print("  comparison.md  — Side-by-side prompt table")
        print("  prompts/       — Individual prompt files")
        print("  scoring.md     — Manual scoring template")
        print()

        # Print the comparison to stdout too
        print(format_comparison_table(all_results))

    # -----------------------------------------------------------------------
    # Phase 3: Generation
    # -----------------------------------------------------------------------
    if args.phase in ("generate", "all"):
        print("=" * 60)
        print("PHASE 3: Submitting selected variants for generation")
        print("=" * 60)
        print()

        if all_results:
            print("Reusing prompt variants from Phase 1.")
        else:
            for shot_id, shot in shots.items():
                all_results[shot_id] = generate_variants(shot)

        hero_ready: list[str] = []
        for shot_id, shot in shots.items():
            hero = _find_hero_frame(args.project, shot_id)
            if hero is None:
                print(f"SKIP {shot_id}: No hero frame found (required for I2V)")
                continue
            print(f"  Hero frame: {hero.name}")
            hero_ready.append(shot_id)

            # Duration from plan data
            duration = (shot.get("routing_data") or {}).get("target_editorial_duration_s") or 5

            for vid in variant_ids:
                variant = all_results.get(shot_id, {}).get(vid)
                if variant is None:
                    continue
                prompt = _usable_prompt(variant)
                if prompt is None:
                    print(f"SKIP {shot_id}/{vid}: {variant['prompt'] or '[empty prompt]'}")
                    continue
                submit_for_generation(
                    args.project, shot_id, prompt,
                    duration=duration,
                    dry_run=args.dry_run,
                )

        _run_extension_tests(args, shots, all_results, shot_ids, variant_ids, hero_ready)


if __name__ == "__main__":
    main()
