#!/usr/bin/env python3
"""First+Last Frame Video Model Shootout.

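Runs several video models on the same start frame (and, when one is configured,
the same end frame) so their outputs can be compared side by side. Each
generation goes through dispatch("video_i2v") with a StepRunner-backed context,
exercising the real production code path. Per-model clips and a manifest.json
are written under config["output_dir"]/<timestamp>/, relative to this script's
directory.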

Usage:
    python3 recoil/tools/shootout/run_shootout.py
    python3 recoil/tools/shootout/run_shootout.py --dry-run
    python3 recoil/tools/shootout/run_shootout.py --models kling-v3,wan-2.7-i2v
"""

import argparse
import json
import logging
import shutil
import sys
import time
from datetime import datetime
from pathlib import Path

# Ensure recoil is importable when this file is executed directly as a script:
# prepend the directory two levels above this file's folder to sys.path.
# This has to run before the ``recoil`` imports that follow.
# NOTE(review): ``import recoil`` only resolves if RECOIL_ROOT *contains* a
# ``recoil`` package directory — confirm this matches the repo layout.
# SCRIPT_DIR is also the anchor for config.json, input frames and outputs.
SCRIPT_DIR = Path(__file__).parent
RECOIL_ROOT = SCRIPT_DIR.parent.parent
sys.path.insert(0, str(RECOIL_ROOT))

from recoil.execution.step_runner import StepRunner
from recoil.execution.step_types import ProjectPaths
from recoil.execution.execution_store import ExecutionStore
from recoil.core import model_profiles
from recoil.pipeline.core.dispatch import dispatch
from recoil.pipeline.core.dispatch_context import DispatchContext
from recoil.pipeline.core.cost import read_cost_from_result

# Root logging for CLI use: timestamp, logger name, level, then the message.
_LOG_FORMAT = "%(asctime)s %(name)s %(levelname)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("shootout")


def load_config(config_path: "str | Path | None" = None) -> dict:
    """Load and return the shootout configuration.

    Args:
        config_path: Optional path to a JSON config file. When omitted, the
            ``config.json`` file in ``SCRIPT_DIR`` (this script's folder) is used.

    Returns:
        The parsed JSON document.
    """
    if config_path is None:
        config_path = SCRIPT_DIR / "config.json"
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (prompts may contain non-ASCII characters such as em dashes).
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)


def resolve_path(rel_path: str) -> Path:
    """Return *rel_path* anchored at the shootout directory, made absolute.

    An absolute *rel_path* wins over ``SCRIPT_DIR`` (standard pathlib join
    semantics); either way the result is passed through ``Path.resolve()``.
    """
    anchored = Path(SCRIPT_DIR, rel_path)
    return anchored.resolve()


def _resolve_end_frame(config: dict, end_frame_key: str) -> "Path | None":
    """Return the resolved end frame for *end_frame_key*, or None if unset.

    Exits with status 1 when the configured file does not exist; warns when
    no end frame is configured (the run then degrades to first-frame-only I2V).
    """
    end_frame_path = config["end_frames"].get(end_frame_key)
    end_frame = resolve_path(end_frame_path) if end_frame_path else None

    if end_frame is not None and not end_frame.exists():
        logger.error("End frame not found: %s", end_frame)
        sys.exit(1)

    if end_frame is None:
        logger.warning(
            "END FRAME IS NULL in config.json['end_frames']['%s']. "
            "This shootout will test FIRST-FRAME-ONLY I2V, not first+last frame. "
            "Generate end frames first: see config.json['end_frame_generation'] for prompts.",
            end_frame_key,
        )
    return end_frame


def _log_cost_estimate(models: list, duration) -> None:
    """Log each model's estimated cost (per-second rate x duration) and the total."""
    logger.info("DRY RUN — estimating costs only")
    total = 0.0
    for model in models:
        cost_per_s = model_profiles.get_cost(model)
        est_cost = cost_per_s * duration
        total += est_cost
        logger.info("  %s: $%.3f (%.3f/s x %ds)", model, est_cost, cost_per_s, duration)
    logger.info("  Total estimated: $%.3f", total)


def _write_manifest(manifest: dict, manifest_path: Path) -> None:
    """Write *manifest* as indented JSON, replacing the file in one step.

    Writing to a sibling temp file and renaming means an interrupted write
    never leaves a truncated manifest behind.
    """
    tmp_path = manifest_path.with_name(manifest_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    tmp_path.replace(manifest_path)


def _run_model(
    model: str,
    *,
    ctx,
    output_base: Path,
    timestamp: str,
    prompt: str,
    start_frame: Path,
    end_frame: "Path | None",
    duration,
) -> dict:
    """Generate one clip with *model* and return its manifest result entry."""
    model_dir = output_base / model.replace(".", "_")
    model_dir.mkdir(parents=True, exist_ok=True)

    shot_id = f"shootout_{timestamp}_{model.replace('.', '_').replace('-', '_')}"
    logger.info("\n--- %s ---", model)

    wall_start = time.monotonic()
    try:
        receipt = dispatch(
            "video_i2v",
            {
                "shot_id": shot_id,
                "prompt": prompt,
                "model": model,
                "start_frame": start_frame,
                "end_frame": end_frame,
                "duration": duration,
                "aspect_ratio": "9:16",
                "generate_audio": False,
            },
            context=ctx,
        )
        result = receipt.run_result
        cost_usd = read_cost_from_result(result)
    except Exception as e:  # one failing model must not abort the whole shootout
        wall_time = time.monotonic() - wall_start
        logger.exception("  EXCEPTION: %s", e)
        return {
            "success": False,
            "cost_usd": 0.0,
            "wall_time_s": round(wall_time, 1),
            "error": str(e),
            "output_path": None,
            "model": model,
        }
    wall_time = time.monotonic() - wall_start

    success = result.success
    error = result.error
    output_path = None
    if success and result.output_path:
        dest = model_dir / "output.mp4"
        try:
            shutil.copy2(result.output_path, dest)
        except OSError as e:
            # The generation itself succeeded (and was paid for), so keep its
            # real cost; only the local copy is missing.
            logger.error("  Could not copy %s to %s: %s", result.output_path, dest, e)
            success = False
            error = f"copy failed: {e}"
        else:
            output_path = str(dest)
            logger.info("  Output: %s", dest)
    elif success:
        logger.warning("  Model reported success but returned no output path")

    status = "OK" if success else f"FAILED: {error}"
    logger.info("  Status: %s", status)
    logger.info("  Cost: $%.3f, Wall time: %.1fs", cost_usd, wall_time)

    return {
        "success": success,
        "cost_usd": cost_usd,
        "wall_time_s": round(wall_time, 1),
        "error": error,
        # Only point at a file that was actually copied into the shootout dir.
        "output_path": output_path,
        "model": model,
    }


def _log_summary(manifest: dict, manifest_path: Path, output_base: Path) -> None:
    """Log a per-model results table followed by the total spend."""
    logger.info("\n" + "=" * 60)
    logger.info("SUMMARY")
    logger.info("=" * 60)
    logger.info("%-15s %-8s %-10s %-10s %s", "Model", "Status", "Cost", "Time", "Output")
    logger.info("-" * 60)
    for model, r in manifest["results"].items():
        status = "OK" if r["success"] else "FAIL"
        cost = f"${r['cost_usd']:.3f}"
        wall = f"{r['wall_time_s']:.0f}s"
        # Prefer the copied clip; otherwise a truncated error; otherwise a dash.
        if r["output_path"]:
            out = r["output_path"]
        elif r["error"]:
            out = str(r["error"])[:40]
        else:
            out = "—"
        logger.info("%-15s %-8s %-10s %-10s %s", model, status, cost, wall, out)

    total_cost = sum(r["cost_usd"] for r in manifest["results"].values())
    logger.info("-" * 60)
    logger.info("Total cost: $%.3f", total_cost)
    logger.info("Manifest: %s", manifest_path)
    logger.info("Review outputs in: %s", output_base)


def run_shootout(
    config: dict,
    dry_run: bool = False,
    models_override: "list[str] | None" = None,
    end_frame_key: str = "hard",
):
    """Execute the shootout across all configured models.

    Every model is asked, through ``dispatch("video_i2v", ...)``, for a clip
    built from the same start frame, prompt and duration (plus the same end
    frame when one is configured). Successful clips are copied to
    ``<output_dir>/<timestamp>/<model>/output.mp4``. A ``manifest.json``
    recording success, cost, wall time and error per model is written next
    to them, and rewritten after every model.

    Args:
        config: Parsed shootout config. Keys read: ``output_dir``,
            ``start_frame``, ``end_frames``, ``models``, ``prompt``,
            ``duration``, ``resolution``.
        dry_run: Only log estimated costs. Nothing is dispatched and nothing
            is written to disk.
        models_override: Model names to run instead of ``config["models"]``.
            A falsy value (None or empty) falls back to the config list.
        end_frame_key: Entry of ``config["end_frames"]`` to use as the last
            frame (the "hard" variant by default).

    Exits the process with status 1 if the configured start or end frame
    file is missing.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_base = SCRIPT_DIR / config["output_dir"] / timestamp

    start_frame = resolve_path(config["start_frame"])
    if not start_frame.exists():
        logger.error("Start frame not found: %s", start_frame)
        sys.exit(1)

    end_frame = _resolve_end_frame(config, end_frame_key)

    models = models_override or config["models"]
    prompt = config["prompt"]
    duration = config["duration"]

    logger.info("=" * 60)
    logger.info("SHOOTOUT: First+Last Frame Video Model Comparison")
    logger.info("=" * 60)
    logger.info("Start frame: %s", start_frame)
    logger.info("End frame: %s (%s)", end_frame, end_frame_key)
    logger.info("Prompt: %s", prompt[:80])
    logger.info("Duration: %ds, Resolution: %s", duration, config["resolution"])
    logger.info("Models: %s", ", ".join(models))
    logger.info("Output: %s", output_base)
    logger.info("-" * 60)

    if dry_run:
        _log_cost_estimate(models, duration)
        return

    # Created only for real runs, so dry runs and early exits leave no empty
    # timestamped directories behind.
    output_base.mkdir(parents=True, exist_ok=True)

    store = ExecutionStore(project="afterimage")
    paths = ProjectPaths.for_episode(project="afterimage", episode=1)
    runner = StepRunner(store=store, paths=paths, validate_frames=False)

    ctx = DispatchContext(
        caller_id="shootout_cli",
        step_runner=runner,
        project="afterimage",
        episode=1,
    )

    manifest = {
        "timestamp": timestamp,
        "config": config,
        "start_frame": str(start_frame),
        # JSON null (not the string "None") for first-frame-only runs.
        "end_frame": str(end_frame) if end_frame is not None else None,
        "results": {},
    }
    manifest_path = output_base / "manifest.json"
    # Persist up front and after each model so already-paid results survive
    # a crash or Ctrl-C part-way through the run.
    _write_manifest(manifest, manifest_path)

    for model in models:
        manifest["results"][model] = _run_model(
            model,
            ctx=ctx,
            output_base=output_base,
            timestamp=timestamp,
            prompt=prompt,
            start_frame=start_frame,
            end_frame=end_frame,
            duration=duration,
        )
        _write_manifest(manifest, manifest_path)

    _log_summary(manifest, manifest_path, output_base)


def _model_list(value: str) -> list:
    """argparse ``type=`` converter for ``--models``.

    Splits on commas, trims whitespace, and drops empty entries and duplicates
    (first occurrence wins). Duplicates are dropped because results are keyed
    by model name.

    Raises:
        argparse.ArgumentTypeError: if no model names remain, so an empty
            ``--models`` cannot silently fall back to running every
            configured (paid) model.
    """
    names = (name.strip() for name in value.split(","))
    models = list(dict.fromkeys(name for name in names if name))
    if not models:
        raise argparse.ArgumentTypeError("expected at least one model name")
    return models


def parse_args(argv=None) -> argparse.Namespace:
    """Parse CLI arguments (``sys.argv[1:]`` when *argv* is None).

    ``models`` is a list of model names, or None when ``--models`` is absent.
    """
    parser = argparse.ArgumentParser(description="First+Last Frame Video Model Shootout")
    parser.add_argument("--dry-run", action="store_true", help="Estimate costs only")
    parser.add_argument("--models", type=_model_list, help="Comma-separated model list override")
    return parser.parse_args(argv)


def main(argv=None) -> None:
    """CLI entry point: load config.json and run (or dry-run) the shootout."""
    args = parse_args(argv)
    config = load_config()
    run_shootout(config, dry_run=args.dry_run, models_override=args.models)


if __name__ == "__main__":
    main()
