#!/usr/bin/env python3
"""
run_vignette.py — Ad-hoc video vignette generator for client work.

Bypasses the plan/shot/bible path of test_via_steprunner.py and drives
StepRunner directly for one-off vignettes where you already have a
start frame (and optionally an end frame) plus handwritten prompts.

Three modes:

  multi-shot   Kling native multi-prompt (3-6 shots in one generation)
               with per-shot prompt + duration.

  in-between   Wan 2.7 I2V with start + end frame interpolation
               (single prompt, single clip, tweened).

  single       Plain I2V from the start frame only
               (single prompt, single clip, no end frame).

Usage:
    # Kling multi-shot (3 shots, 10s total)
    python -m tools.run_vignette \\
        --project driver-beware \\
        --mode multi-shot \\
        --slug V1_stay_in_car \\
        --model kling-v3 \\
        --aspect 16:9 \\
        --start-frame output/frames/ep_001/clean/v1_stay_in_car_start.png \\
        --shots-json '[{"prompt":"...","duration":4},{"prompt":"...","duration":3},{"prompt":"...","duration":3}]'

    # Wan 2.7 In-Between (start + end frame, single tween clip)
    python -m tools.run_vignette \\
        --project driver-beware \\
        --mode in-between \\
        --slug V4_cautionary \\
        --model wan-2.7-i2v \\
        --aspect 16:9 \\
        --duration 6 \\
        --start-frame output/frames/ep_001/clean/v4_cautionary_start.png \\
        --end-frame output/frames/ep_001/clean/v4_cautionary_end.png \\
        --prompt "Smooth cel-animation tween..."

Outputs land under ExecutionStore's managed paths for the selected episode
(e.g. output/video/ep_001/ with the default --episode 1).
"""

import argparse
import json
import logging
import sys
import time
from pathlib import Path

# Put the project root (two levels above this file) at the front of sys.path
# so the top-level `lib` / `orchestrator` imports below resolve however the
# script is launched. This must run before the project imports that follow.
_PROJECT_ROOT = Path(__file__).parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from lib.constants import PROJECTS_ROOT

# Logger lives under the "pipeline" hierarchy, so pipeline-wide logging config applies.
logger = logging.getLogger("pipeline.run_vignette")


def _get_store(project: str):
    """Return a fresh ``ExecutionStore`` bound to *project*.

    ``ExecutionStore`` is imported inside the function body, so
    ``lib.execution_store`` is only loaded once a store is actually needed.
    """
    from lib.execution_store import ExecutionStore as _Store

    store = _Store(project)
    return store


def _resolve_frame(project: str, p: str) -> Path:
    path = Path(p)
    if not path.is_absolute():
        path = PROJECTS_ROOT / project / p
    if not path.is_file():
        raise FileNotFoundError(f"Frame not found: {path}")
    return path


def run_multi_shot(
    project: str,
    slug: str,
    model: str,
    start_frame: Path,
    shots: list[dict],
    aspect_ratio: str = "16:9",
    episode: int = 1,
) -> dict:
    """Run a Kling-style multi-shot vignette as one multi-prompt generation.

    Args:
        project: Project name used for the ExecutionStore and episode paths.
        slug: Vignette identifier; shot IDs become ``<SLUG>_SH01``, ``<SLUG>_SH02``, ...
        model: Model name (e.g. ``kling-v3``). Names containing "kling"
            (case-insensitive) get the per-shot prompt length check.
        start_frame: Existing start-frame image.
        shots: List of ``{"prompt": str, "duration": int}``; ``duration``
            defaults to 5 seconds when omitted.
        aspect_ratio: Aspect ratio forwarded to the runner.
        episode: Episode number used to select output paths.

    Returns:
        Summary dict with ``slug``, ``mode``, ``elapsed_s``, a per-shot
        ``shots`` list and ``total_cost``.

    Raises:
        ValueError: if ``shots`` is empty, a shot has no prompt, a duration is
            not positive, or a Kling prompt exceeds the per-shot cap. All checks
            run before the runner/store is created, so nothing is submitted on
            a validation failure.
    """
    # Kling multi-shot via fal.ai caps per-shot prompts at 512 chars.
    KLING_MULTI_SHOT_PROMPT_MAX = 512

    if not shots:
        raise ValueError("Multi-shot requires at least one shot.")
    is_kling = "kling" in model.lower()

    # Build batch (one entry per shot) and multi_prompt_sequence, validating
    # every shot up front because fal.ai charges on submission.
    batch = []
    sequence = []
    for i, s in enumerate(shots, start=1):
        prompt = s.get("prompt")
        if not prompt:
            raise ValueError(f"Shot {i} has no prompt.")
        dur = int(s.get("duration", 5))
        if dur <= 0:
            raise ValueError(f"Shot {i} duration must be positive, got {dur}.")
        if is_kling and len(prompt) > KLING_MULTI_SHOT_PROMPT_MAX:
            raise ValueError(
                f"Shot {i} prompt is {len(prompt)} chars — Kling multi-shot "
                f"caps per-shot prompts at {KLING_MULTI_SHOT_PROMPT_MAX}. "
                f"Trim this prompt before submitting (fal.ai charges on submission)."
            )
        shot_id = f"{slug.upper()}_SH{i:02d}"
        batch.append({"shot_id": shot_id, "_api_duration": dur})
        sequence.append({"index": i, "prompt": prompt, "duration": dur})

    total = sum(entry["_api_duration"] for entry in batch)
    logger.info(
        "Multi-shot %s: %d shots, %ds total, model=%s, aspect=%s",
        slug, len(batch), total, model, aspect_ratio,
    )

    # Deferred until inputs are known-good so validation has no side effects.
    from orchestrator.step_runner import StepRunner
    from orchestrator.step_types import ProjectPaths

    store = _get_store(project)
    paths = ProjectPaths.for_episode(project, episode)
    runner = StepRunner(store=store, paths=paths)

    t0 = time.time()
    # Materialize once: the results are iterated twice below.
    results = list(
        runner.execute_multi_shot(
            batch=batch,
            multi_prompt_sequence=sequence,
            model=model,
            start_frame=start_frame,
            aspect_ratio=aspect_ratio,
        )
    )
    elapsed = time.time() - t0

    return {
        "slug": slug,
        "mode": "multi-shot",
        "elapsed_s": round(elapsed, 1),
        "shots": [
            {
                "shot_id": r.shot_id,
                "success": r.success,
                "output_path": str(r.output_path) if r.output_path else None,
                "cost_usd": r.cost_usd,
                "error": getattr(r, "error", None),
            }
            for r in results
        ],
        # Treat a missing cost (e.g. on a failed shot) as 0 rather than crashing
        # after the generation has already been paid for.
        "total_cost": sum(r.cost_usd or 0 for r in results),
    }


def run_in_between(
    project: str,
    slug: str,
    model: str,
    start_frame: Path,
    end_frame: Path | None,
    prompt: str,
    duration: int = 6,
    aspect_ratio: str = "16:9",
    episode: int = 1,
    negative_prompt: str = None,
) -> dict:
    """Run a single-shot video. If end_frame is given, runs an In-Between
    tween (start → end). If end_frame is None, runs a normal I2V from the
    start frame."""
    from orchestrator.step_runner import StepRunner
    from orchestrator.step_types import ProjectPaths

    store = _get_store(project)
    paths = ProjectPaths.for_episode(project, episode)
    runner = StepRunner(store=store, paths=paths)

    shot_id = f"{slug.upper()}_SH01"
    logger.info(
        "In-between %s: %ds, model=%s, aspect=%s", slug, duration, model, aspect_ratio,
    )

    if negative_prompt is None and "wan" in model:
        negative_prompt = "morphing, melting, teleporting, fast cuts, text, watermark"

    t0 = time.time()
    result = runner.execute_video(
        shot_id=shot_id,
        prompt=prompt,
        model=model,
        start_frame=start_frame,
        end_frame=end_frame,
        duration=duration,
        aspect_ratio=aspect_ratio,
        negative_prompt=negative_prompt,
    )
    elapsed = time.time() - t0

    return {
        "slug": slug,
        "mode": "in-between",
        "elapsed_s": round(elapsed, 1),
        "shot_id": shot_id,
        "success": result.success,
        "output_path": str(result.output_path) if result.output_path else None,
        "cost_usd": result.cost_usd,
        "error": getattr(result, "error", None),
    }


def main():
    """Command-line entry point for the vignette runner.

    Parses arguments, validates the mode-specific requirements, resolves
    frame paths, runs the selected mode and prints the result summary as JSON
    on stdout.

    Exit status: 2 (via ``parser.error``) for missing or invalid arguments,
    malformed ``--shots-json`` or missing frame files; 1 when the generation
    reports failure; 0 otherwise.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%H:%M:%S",
    )
    parser = argparse.ArgumentParser(description="Ad-hoc video vignette runner")
    parser.add_argument("--project", required=True)
    parser.add_argument("--slug", required=True, help="Vignette identifier (used for shot IDs)")
    parser.add_argument(
        "--mode", required=True, choices=["multi-shot", "in-between", "single"],
        help="multi-shot: Kling multi-prompt batch | in-between: Wan I2V with start+end | single: single-shot I2V with start frame only",
    )
    parser.add_argument("--model", required=True, help="e.g. kling-v3, wan-2.7-i2v")
    parser.add_argument("--aspect", default="16:9")
    parser.add_argument("--start-frame", required=True, help="Path (abs or project-relative)")
    parser.add_argument("--end-frame", help="Required for --mode in-between")
    parser.add_argument(
        "--shots-json",
        help='Multi-shot: JSON list of {"prompt","duration"} (required for --mode multi-shot)',
    )
    parser.add_argument("--prompt", help="Single prompt (required for --mode in-between)")
    parser.add_argument("--duration", type=int, default=6, help="Seconds (in-between mode)")
    parser.add_argument("--negative-prompt", help="Override negative prompt")
    parser.add_argument("--episode", type=int, default=1)
    args = parser.parse_args()

    # Validate purely command-line inputs first, so mistakes surface as usage
    # errors before any file lookups or paid API calls.
    shots = None
    if args.mode == "multi-shot":
        if not args.shots_json:
            parser.error("--shots-json required for --mode multi-shot")
        try:
            shots = json.loads(args.shots_json)
        except json.JSONDecodeError as exc:
            parser.error(f"--shots-json is not valid JSON: {exc}")
        if not isinstance(shots, list) or not all(isinstance(s, dict) for s in shots):
            parser.error('--shots-json must be a JSON list of {"prompt", "duration"} objects')
    elif args.mode == "in-between":
        if not args.end_frame or not args.prompt:
            parser.error("--end-frame and --prompt required for --mode in-between")
    elif not args.prompt:  # single
        parser.error("--prompt required for --mode single")

    if args.end_frame and args.mode != "in-between":
        # Only in-between mode uses an end frame; warn rather than drop it silently.
        logger.warning("--end-frame is ignored for --mode %s", args.mode)

    try:
        start_frame = _resolve_frame(args.project, args.start_frame)
        end_frame = (
            _resolve_frame(args.project, args.end_frame)
            if args.mode == "in-between"
            else None
        )
    except FileNotFoundError as exc:
        parser.error(str(exc))

    if args.mode == "multi-shot":
        result = run_multi_shot(
            project=args.project,
            slug=args.slug,
            model=args.model,
            start_frame=start_frame,
            shots=shots,
            aspect_ratio=args.aspect,
            episode=args.episode,
        )
        succeeded = bool(result["shots"]) and all(s["success"] for s in result["shots"])
    else:
        # in-between and single share one code path; single has no end frame.
        result = run_in_between(
            project=args.project,
            slug=args.slug,
            model=args.model,
            start_frame=start_frame,
            end_frame=end_frame,
            prompt=args.prompt,
            duration=args.duration,
            aspect_ratio=args.aspect,
            episode=args.episode,
            negative_prompt=args.negative_prompt,
        )
        succeeded = bool(result["success"])

    print(json.dumps(result, indent=2))
    if not succeeded:
        sys.exit(1)


if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0
    # unless main() itself exits earlier.
    sys.exit(main())
