#!/usr/bin/env python3
"""
generate_ep001_videos.py — Generate 9 videos from EP001 keyframes.

Primary: Veo 3.1 image-to-video (duration clamped to 4-8s).
Fallback: Kling V3 I2V (with corrected poll endpoint for image2video).

Skips shots whose output file already exists on disk and is larger than 10,000 bytes.
Processes shots SEQUENTIALLY to avoid rate limiting.

Usage:
    python3 tools/generate_ep001_videos.py
    python3 tools/generate_ep001_videos.py --kling-only   # Skip Veo, use Kling for all
"""

import base64
import json
import logging
import os
import sys
import time
import urllib.request
import urllib.error
from pathlib import Path

# Make the starsend root (two directory levels above this file) importable
# ahead of everything else, so the ``lib`` package import below resolves.
STARSEND_ROOT = Path(__file__).parent.parent
sys.path.insert(0, os.fspath(STARSEND_ROOT))

from lib.api_client import GoogleGenaiClient, KlingClient, GenerationResult

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

PROJECTS_ROOT = Path(os.path.expanduser("~/Dropbox/CLAUDE_PROJECTS/projects"))
PROJECT = "starsend-test"
EPISODE = "ep_001"

# Every path below lives under this project's directory.
_PROJECT_DIR = PROJECTS_ROOT / PROJECT
PLAN_PATH = _PROJECT_DIR / "state" / "starsend" / "plans" / "ep_001_plan.json"
FRAMES_DIR = _PROJECT_DIR / "output" / "frames" / "ep_001"
OUTPUT_DIR = _PROJECT_DIR / "output" / "video" / "ep_001"

# The EP001 shot IDs, in order. The suffix after "_SH" ("01", "02A", ...)
# drives the output filename; the full ID drives the keyframe filename.
_EP001_SHOT_IDS = (
    "EP001_SH01",
    "EP001_SH02",
    "EP001_SH02A",
    "EP001_SH03",
    "EP001_SH03A",
    "EP001_SH04",
    "EP001_SH05",
    "EP001_SH06",
    "EP001_SH07",
)

# Shot ID -> output video filename, e.g. "EP001_SH02A" -> "shot_002a.mp4".
SHOT_FILENAME_MAP = {
    shot_id: f"shot_0{shot_id.partition('_SH')[2].lower()}.mp4"
    for shot_id in _EP001_SHOT_IDS
}

# Shot ID -> keyframe image filename inside FRAMES_DIR.
SHOT_KEYFRAME_MAP = {
    shot_id: f"STA_EP001_S00_shot_{shot_id}.png" for shot_id in _EP001_SHOT_IDS
}

VEO_MODEL = "veo-3.1-generate-preview"
# Requested Veo durations are clamped into [VEO_MIN_DURATION, VEO_MAX_DURATION].
VEO_MIN_DURATION = 4
VEO_MAX_DURATION = 8
ASPECT_RATIO = "9:16"
TIMEOUT_S = 600  # Veo wait_for_job timeout, in seconds

# Our own Kling I2V polling cadence, used instead of the library's backoff:
# 40 polls x 15 s = a 10 minute ceiling per shot.
KLING_POLL_INTERVAL = 15
KLING_POLL_MAX = 40

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Prompt construction
# ---------------------------------------------------------------------------

def build_video_prompt(shot: dict) -> str:
    """Build the Veo prompt for a shot.

    Priority:
      1. ``compiled_prompts["veo_t2v"]`` (pre-compiled, returned verbatim)
      2. Constructed from ``prompt_data`` (shot type, focal length, prompt
         skeleton subject/environment/emotion lines, kinetic action, lighting
         sources) plus ``routing_data.target_editorial_duration_s``.

    Args:
        shot: One entry from the plan's ``shots`` list.

    Returns:
        The prompt text. Constructed prompts are normalized to single spaces
        with no repeated periods.

    Raises:
        KeyError: If there is no pre-compiled prompt and the shot lacks
            ``routing_data.target_editorial_duration_s``.
    """
    # ``or {}`` also covers keys that are present but set to JSON null.
    prompts = shot.get("compiled_prompts") or {}
    if "veo_t2v" in prompts:
        return prompts["veo_t2v"]

    # Only the constructed prompt needs the duration, so read it after the
    # early return (a pre-compiled prompt works without routing_data).
    duration = shot["routing_data"]["target_editorial_duration_s"]

    pd = shot.get("prompt_data") or {}
    skeleton = pd.get("prompt_skeleton") or {}
    kinetic = pd.get("kinetic_action", "")
    shot_type = pd.get("shot_type", "")
    focal = pd.get("focal_length", "")

    # One phrase per light source, e.g. "warm soft light from window".
    # Missing fields are dropped rather than leaving stray spaces behind.
    light_phrases = []
    for src in (pd.get("lighting") or {}).get("sources") or []:
        color_temp = src.get("color_temp", "")
        quality = src.get("quality", "")
        motivator = src.get("motivator", "")
        if not (color_temp or quality or motivator):
            continue
        phrase = " ".join(str(word) for word in (color_temp, quality, "light") if word)
        if motivator:
            phrase += f" from {motivator}"
        light_phrases.append(phrase)
    lighting_desc = ". ".join(light_phrases)

    subject = skeleton.get("subject_line", "")
    environment = skeleton.get("environment_line", "")
    emotion = skeleton.get("emotion_line", "")

    lines = [f"Cinematic video, {duration} seconds."]
    if shot_type and focal:
        lines.append(f"{shot_type}, {focal}.")
    if subject:
        lines.append(f"{subject}.")
    if kinetic:
        lines.append(f"Motion: {kinetic}.")
    if environment:
        lines.append(f"{environment}.")
    if lighting_desc:
        lines.append(f"{lighting_desc}.")
    if emotion:
        lines.append(f"{emotion.capitalize()}.")
    lines.append("Photorealistic, cinematic lighting, high budget film.")

    prompt = " ".join(lines)
    # Fragments that already end in "." gain a second one above. Collapse every
    # run of periods and of whitespace; a single str.replace() pass would turn
    # "..." into ".." and three spaces into two.
    while ".." in prompt:
        prompt = prompt.replace("..", ".")
    return " ".join(prompt.split())


def build_kling_prompt(shot: dict) -> str:
    """Build the Kling image-to-video prompt for a shot.

    Uses ``compiled_prompts["kling_i2v"]`` verbatim when present. Otherwise
    joins the shot type, kinetic action and emotion line with commas and
    appends a fixed style suffix.

    Args:
        shot: One entry from the plan's ``shots`` list.

    Returns:
        The prompt text.
    """
    # ``or {}`` also covers keys that are present but set to JSON null.
    prompts = shot.get("compiled_prompts") or {}
    if "kling_i2v" in prompts:
        return prompts["kling_i2v"]

    pd = shot.get("prompt_data") or {}
    # Use ``or`` rather than a .get() default so an empty or null
    # kinetic_action still falls back to a generic motion cue.
    kinetic = pd.get("kinetic_action") or "subtle ambient motion"
    shot_type = pd.get("shot_type", "")
    emotion = (pd.get("prompt_skeleton") or {}).get("emotion_line", "")
    parts = [part for part in (shot_type, kinetic, emotion) if part]
    return ", ".join(parts) + ". Cinematic, photorealistic."


# ---------------------------------------------------------------------------
# Kling I2V with corrected polling
# ---------------------------------------------------------------------------

def kling_i2v_generate(shot: dict, keyframe_path: Path) -> GenerationResult:
    """Generate a video for one shot via Kling V3 image-to-video.

    The library's poll_status() hardcodes /v1/videos/text2video/{task_id},
    which fails for I2V tasks. This function therefore submits through the
    client and then polls /v1/videos/image2video/{task_id} itself, every
    KLING_POLL_INTERVAL seconds for at most KLING_POLL_MAX attempts.

    Args:
        shot: One entry from the plan's ``shots`` list. It must contain
            ``routing_data.target_editorial_duration_s``.
        keyframe_path: Start-frame image, sent base64-encoded.

    Returns:
        A GenerationResult. On success it carries the downloaded video bytes,
        the source URL and an estimated cost. Otherwise ``success=False`` and
        ``error`` describes one of:
          - missing keys
          - submit failure
          - task failure
          - empty result
          - timeout
    """
    client = KlingClient()
    if not client.is_available():
        return GenerationResult(
            success=False, model="kling-v3",
            error="Kling API keys not configured",
        )

    image_b64 = base64.b64encode(keyframe_path.read_bytes()).decode()
    prompt = build_kling_prompt(shot)
    duration = shot["routing_data"]["target_editorial_duration_s"]

    payload = {
        "mode": "image2video",
        "prompt": prompt,
        "image": image_b64,
        "duration": duration,
        # Share the Veo path's aspect ratio instead of duplicating the literal.
        "aspect_ratio": ASPECT_RATIO,
    }

    logger.info("  Kling I2V: submitting...")
    logger.info("  Kling prompt: %s", prompt[:100])
    job = client.submit(payload)

    if job.status == "failed":
        return job.result or GenerationResult(
            success=False, model="kling-v3", error=job.error or "Submit failed"
        )

    task_id = job.job_id
    logger.info("  Kling task ID: %s — polling every %ds...", task_id, KLING_POLL_INTERVAL)

    for attempt in range(KLING_POLL_MAX):
        time.sleep(KLING_POLL_INTERVAL)
        try:
            # _request is a private KlingClient helper. It is used here because
            # the public poll_status() targets the text2video endpoint.
            response = client._request("GET", f"/v1/videos/image2video/{task_id}")
            # ``or {}`` guards against a "data": null response body.
            data = response.get("data") or {}
        except Exception as e:
            logger.warning("  Kling poll error: %s", e)
            continue  # transient errors are common; keep polling

        status = data.get("task_status", "unknown")
        logger.info("  Kling poll #%d: status=%s", attempt + 1, status)

        if status == "failed":
            return GenerationResult(
                success=False, model="kling-v3",
                error=data.get("task_status_msg") or "Kling generation failed",
            )
        if status != "succeed":
            continue  # still processing

        videos = (data.get("task_result") or {}).get("videos") or []
        if not videos:
            return GenerationResult(
                success=False, model="kling-v3",
                error="Kling: no videos in result",
            )
        video_url = videos[0].get("url")
        if not video_url:
            return GenerationResult(
                success=False, model="kling-v3",
                error="Kling: no video URL in result",
            )

        logger.info("  Kling: downloading video...")
        try:
            req = urllib.request.Request(video_url)
            with urllib.request.urlopen(req, timeout=60) as resp:
                video_data = resp.read()
        except Exception as e:
            # The task already succeeded. The next iteration re-polls and
            # retries the download rather than failing the shot outright.
            logger.warning("  Kling download error: %s", e)
            continue

        # Kling V3 standard: $0.10/sec, estimated on the duration rounded by
        # KlingClient._round_duration.
        cost = 0.10 * KlingClient._round_duration(duration)

        return GenerationResult(
            success=True,
            video_data=video_data,
            video_url=video_url,
            model="kling-v3",
            cost=cost,
        )

    return GenerationResult(
        success=False, model="kling-v3",
        error=f"Kling timed out after {KLING_POLL_MAX * KLING_POLL_INTERVAL}s",
    )


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _write_video(output_path: Path, video_data: bytes) -> int:
    """Write ``video_data`` to ``output_path`` atomically and return its size in bytes.

    The bytes go to a sibling ``<name>.part`` file, which is then renamed over
    the target. An interrupted run therefore never leaves a truncated
    ``shot_*.mp4`` that the next run's skip check (size > 10000 bytes)
    would accept as finished.
    """
    tmp_path = output_path.with_name(output_path.name + ".part")
    try:
        tmp_path.write_bytes(video_data)
        os.replace(tmp_path, output_path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    return output_path.stat().st_size


def _try_veo(veo_client: GoogleGenaiClient, shot: dict, keyframe_path: Path,
             output_path: Path, duration):
    """Attempt one Veo 3.1 generation for a shot.

    Returns:
        A results entry (dict) on success. Returns None after logging a
        warning on submit failure, job failure or any Exception, so the
        caller can fall back to Kling.
    """
    # NOTE(review): Veo 3.1 may only accept 4, 6 or 8 s. Confirm whether
    # GoogleGenaiClient snaps odd values (5, 7) that this clamp lets through.
    veo_duration = max(VEO_MIN_DURATION, min(VEO_MAX_DURATION, duration))
    if veo_duration != duration:
        logger.info("  Veo duration clamped: %ds -> %ds", duration, veo_duration)

    prompt = build_video_prompt(shot)
    logger.info("  Veo prompt: %s", prompt[:120] + ("..." if len(prompt) > 120 else ""))

    payload = {
        "model": VEO_MODEL,
        "prompt": prompt,
        "start_frame": str(keyframe_path),
        "duration": veo_duration,
        "aspect_ratio": ASPECT_RATIO,
    }

    start_time = time.time()
    try:
        logger.info("  Submitting to Veo 3.1...")
        job = veo_client.submit(payload)
        if job.status == "failed":
            logger.warning("  Veo submit failed: %s", job.error)
            return None

        logger.info("  Veo job submitted (ID: %s). Polling...", job.job_id)
        result = veo_client.wait_for_job(job, timeout_s=TIMEOUT_S)
        elapsed = time.time() - start_time
        if not (result.success and result.video_data):
            logger.warning("  Veo failed: %s (%.0fs)", result.error, elapsed)
            return None

        file_size = _write_video(output_path, result.video_data)
    except Exception as e:
        logger.warning("  Veo exception: %s", e)
        return None

    # Veo 3.1 cost estimate: $0.05/sec of the clamped duration.
    # NOTE(review): verify this rate against current Veo pricing.
    cost = 0.05 * veo_duration
    logger.info("  VEO SUCCESS: %s (%.1f MB, $%.3f, %.0fs)",
                output_path.name, file_size / 1024 / 1024, cost, elapsed)
    return {
        "success": True,
        "cost": cost,
        "file_size": file_size,
        "elapsed": elapsed,
        "model": "veo-3.1",
    }


def _try_kling(shot: dict, keyframe_path: Path, output_path: Path, duration) -> dict:
    """Generate one shot with Kling V3 I2V and return its results entry.

    Failures and any Exception are logged and returned as a
    ``{"success": False, ...}`` entry rather than raised.
    """
    logger.info("  Using Kling V3 I2V...")
    start_time = time.time()
    try:
        result = kling_i2v_generate(shot, keyframe_path)
        elapsed = time.time() - start_time

        if not (result.success and result.video_data):
            error = result.error or "No video data"
            logger.error("  KLING FAILED: %s (%.0fs)", error, elapsed)
            return {"success": False, "error": error, "model": "kling-v3"}

        file_size = _write_video(output_path, result.video_data)
        # Fall back to the $0.10/sec estimate when no (or a zero) cost is reported.
        cost = result.cost or (0.10 * KlingClient._round_duration(duration))
    except Exception as e:
        elapsed = time.time() - start_time
        logger.error("  KLING EXCEPTION: %s (%.0fs)", e, elapsed)
        return {"success": False, "error": str(e), "model": "kling-v3"}

    logger.info("  KLING SUCCESS: %s (%.1f MB, $%.3f, %.0fs)",
                output_path.name, file_size / 1024 / 1024, cost, elapsed)
    return {
        "success": True,
        "cost": cost,
        "file_size": file_size,
        "elapsed": elapsed,
        "model": "kling-v3",
    }


def _log_summary(shots: list, results: dict, total_generated: int, total_cost: float) -> None:
    """Log the run summary.

    Covers the totals, a per-shot table, the failed shots, and the
    shot_*.mp4 files on disk.
    """
    logger.info("")
    logger.info("=" * 60)
    logger.info("GENERATION COMPLETE")
    logger.info("=" * 60)
    logger.info("")
    logger.info("Total videos: %d / %d", total_generated, len(shots))
    logger.info("Total cost:   $%.2f", total_cost)
    logger.info("")

    logger.info("Results per shot:")
    # ``results`` holds one entry per plan shot, in plan order. Iterating it
    # (rather than SHOT_FILENAME_MAP) has two effects:
    #   - unmapped plan shots are still listed;
    #   - mapped shots absent from the plan are not reported as failures.
    for shot_id, r in results.items():
        filename = SHOT_FILENAME_MAP.get(shot_id, "-")
        if r.get("success"):
            size_mb = r.get("file_size", 0) / 1024 / 1024
            model = r.get("model", "?")
            cost = r.get("cost", 0)
            note = r.get("note", "")
            extra = f" ({note})" if note else ""
            logger.info("  %-14s %-14s %5.1f MB  $%.3f  [%s]%s",
                        shot_id, filename, size_mb, cost, model, extra)
        else:
            logger.info("  %-14s %-14s FAILED: %s",
                        shot_id, filename, r.get("error", "unknown"))

    final_failures = [sid for sid, r in results.items() if not r.get("success")]
    if final_failures:
        logger.info("")
        logger.info("FAILED SHOTS (%d):", len(final_failures))
        for sid in final_failures:
            logger.info("  %s: %s", sid, results[sid].get("error", "unknown"))

    logger.info("")
    logger.info("Output directory: %s", OUTPUT_DIR)
    mp4_files = sorted(OUTPUT_DIR.glob("shot_*.mp4"))
    logger.info("shot_*.mp4 files on disk: %d", len(mp4_files))
    for f in mp4_files:
        logger.info("  %s  (%.1f MB)", f.name, f.stat().st_size / 1024 / 1024)


def _write_manifest(shots: list, results: dict, total_generated: int, total_cost: float) -> Path:
    """Write generation_manifest.json into OUTPUT_DIR and return its path."""
    manifest_path = OUTPUT_DIR / "generation_manifest.json"
    manifest = {
        "episode": EPISODE,
        "project": PROJECT,
        "total_shots": len(shots),
        "total_generated": total_generated,
        "total_cost_usd": round(total_cost, 3),
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "results": results,
    }
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    logger.info("Manifest saved: %s", manifest_path)
    return manifest_path


def main():
    """Generate every shot in the EP001 plan, then log a summary and write a manifest.

    Shots are processed sequentially.
      - Each is tried with Veo 3.1 first, unless ``--kling-only`` is passed or
        the Veo client is unavailable, and falls back to Kling V3 I2V.
      - Shots whose output file already exists and exceeds 10000 bytes are
        counted as done without regenerating.
      - A 5 s pause separates generated shots.

    Returns:
        0 if every shot in the plan has a video, otherwise 1. Calls
        ``sys.exit(1)`` directly when Veo mode is requested without
        GEMINI_API_KEY.
    """
    kling_only = "--kling-only" in sys.argv

    # Validate environment
    if not kling_only and not os.environ.get("GEMINI_API_KEY"):
        logger.error("GEMINI_API_KEY not set. Aborting.")
        sys.exit(1)

    logger.info("Loading plan: %s", PLAN_PATH)
    with open(PLAN_PATH, encoding="utf-8") as f:
        plan = json.load(f)

    shots = plan["shots"]
    logger.info("Loaded %d shots from plan", len(shots))

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    logger.info("Output directory: %s", OUTPUT_DIR)

    if kling_only:
        logger.info("Mode: Kling-only (skipping Veo)")
    else:
        logger.info("Mode: Veo 3.1 primary, Kling V3 fallback")

    veo_client = None
    if not kling_only:
        veo_client = GoogleGenaiClient()
        if not veo_client.is_available():
            logger.warning("Veo client not available, falling back to Kling-only")
            kling_only = True

    results = {}
    total_cost = 0.0
    total_generated = 0

    for i, shot in enumerate(shots):
        shot_id = shot["shot_id"]
        output_filename = SHOT_FILENAME_MAP.get(shot_id)
        keyframe_filename = SHOT_KEYFRAME_MAP.get(shot_id)

        if not output_filename or not keyframe_filename:
            logger.warning("Shot %s: no filename mapping, skipping", shot_id)
            results[shot_id] = {"success": False, "error": "No filename mapping"}
            continue

        keyframe_path = FRAMES_DIR / keyframe_filename
        output_path = OUTPUT_DIR / output_filename

        # Skip already-generated shots. This script only renames a video into
        # place once it is fully written (see _write_video).
        existing_size = output_path.stat().st_size if output_path.exists() else 0
        if existing_size > 10000:
            logger.info("[%d/%d] Shot %s — SKIPPED (already exists: %.1f MB)",
                        i + 1, len(shots), shot_id, existing_size / 1024 / 1024)
            results[shot_id] = {
                "success": True,
                "cost": 0.0,
                "file_size": existing_size,
                "model": "cached",
                "note": "already existed on disk",
            }
            total_generated += 1
            continue

        logger.info("=" * 60)
        logger.info("[%d/%d] Shot %s", i + 1, len(shots), shot_id)
        logger.info("  Keyframe: %s", keyframe_path.name)
        logger.info("  Output:   %s", output_path.name)

        if not keyframe_path.exists():
            logger.error("  Keyframe not found: %s", keyframe_path)
            results[shot_id] = {"success": False, "error": "Keyframe not found"}
            continue

        duration = shot["routing_data"]["target_editorial_duration_s"]
        logger.info("  Target duration: %ds", duration)

        entry = None
        if not kling_only:
            entry = _try_veo(veo_client, shot, keyframe_path, output_path, duration)
        if entry is None:
            entry = _try_kling(shot, keyframe_path, output_path, duration)

        results[shot_id] = entry
        if entry["success"]:
            total_cost += entry["cost"]
            total_generated += 1

        # Brief pause between shots to avoid rate limiting.
        if i < len(shots) - 1:
            logger.info("  Waiting 5s before next shot...")
            time.sleep(5)

    _log_summary(shots, results, total_generated, total_cost)
    _write_manifest(shots, results, total_generated, total_cost)

    return 0 if total_generated == len(shots) else 1


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
