#!/usr/bin/env python3
"""
generate_previs.py — Flash 3.1 previs frame generator.

Generates 1x 9:16 frame per shot via Flash 3.1 ($0.039/frame) as a
mandatory previs gate (ADR H02) before keyframe/video generation.

Each previs frame:
  - Uses build_previz_context() to assemble the shot context; Flash authors its own prompt
  - Runs Gate 1 text critique on the authored prompt (character accuracy)
  - Runs Gate 1 mechanical QC on the result image
  - Saves to sequences/ep_{NNN}/shot_{NNN}.png (later takes: shot_{NNN}_take{NNN}.png)

Cost per episode (~39 shots): ~$1.52

Usage:
    python -m tools.generate_previs --episode 1
    python -m tools.generate_previs --episode 1 --shots 1-5
    python -m tools.generate_previs --episode 1 --dry-run
"""

from __future__ import annotations

import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Optional

# Directory two levels above this file. It is prepended to sys.path (once)
# before the `recoil.*` imports that follow — presumably so they resolve when
# this file is executed directly rather than as an installed package.
_PROJECT_ROOT = Path(__file__).parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from recoil.core.paths import DEFAULT_PROJECT, ProjectPaths
from recoil.core.exceptions import WorkspaceStateCorruptError
from recoil.pipeline._lib.previz_context import build_previz_context
from recoil.core.model_profiles import get_model
from recoil.execution.execution_store import ExecutionStore

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Image model used for previs frames, resolved from the "exploration" profile
# at import time, and the flat per-frame cost charged for each generation
# attempt (the module docstring quotes $0.039/frame for Flash 3.1).
PREVIS_MODEL = get_model("exploration", "image")
PREVIS_COST = 0.039  # per frame

# Image model used by the NBP keyframe helpers, resolved from the "production"
# profile. NBP_COST is not referenced anywhere else in this file; presumably
# it is the USD per-frame price consumed by callers — TODO confirm.
NBP_MODEL = get_model("production", "image")
NBP_COST = 0.134  # per frame


def _load_bible(project: str) -> Optional[dict]:
    """Load the global bible for location enrichment.

    Args:
        project: Project name used to resolve the bible path.

    Returns:
        The parsed JSON content of the bible, or None when no bible file
        exists yet (legitimate — previs degrades gracefully without it).

    Raises:
        WorkspaceStateCorruptError: A bible file EXISTS but cannot be read,
            is not valid UTF-8, or is not valid JSON. Failing loudly avoids
            silently dropping all location enrichment from a corrupt bible
            (REC-232).
    """
    bible_path = ProjectPaths.for_project(project).global_bible_path
    if not bible_path.exists():
        return None
    try:
        return json.loads(bible_path.read_text(encoding="utf-8"))
    # UnicodeDecodeError is a ValueError but neither a JSONDecodeError nor an
    # OSError, so it must be listed explicitly or a mis-encoded bible would
    # escape as a raw decode error instead of the documented exception.
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
        raise WorkspaceStateCorruptError(
            str(bible_path), kind="global bible", message=str(e)
        ) from e


def _run_gate_1_text_for_previs(authored: str, shot: dict) -> dict:
    """Critique an authored prompt with Gate 1 text, failing closed on error.

    Args:
        authored: The prompt text to critique.
        shot: The shot dict the prompt was authored for.

    Returns:
        A dict with "passed", "details" and "cost" keys. If the validator
        cannot be imported, constructed or run, the error is logged and the
        result is a failure ("passed": False) with zero cost.
    """
    try:
        # Imported lazily so a missing/broken validator is handled below.
        from recoil.pipeline._lib.validation import Validator

        verdict = Validator().run_gate_1_text(authored, shot)
        outcome = {
            "passed": verdict.passed,
            "details": verdict.details,
            "cost": verdict.cost,
        }
    except Exception as exc:
        logger.error("Gate 1 text unavailable: %s", exc)
        outcome = {
            "passed": False,
            "details": {"error": f"Gate 1 text unavailable: {exc}"},
            "cost": 0.0,
        }
    return outcome


def generate_previs(
    episode: int,
    shot_ids: Optional[list] = None,
    project: Optional[str] = None,
    dry_run: bool = False,
) -> list[dict]:
    """Generate previs frames for an episode.

    For each selected shot this builds the previz context, asks Flash for one
    frame (plus the prompt it authored), runs the Gate 1 text critique on that
    prompt (regenerating once if the critique fails), writes the image to a
    fresh take file, runs the Gate 1 mechanical QC on the saved image and
    registers the take in the execution store. After a live run a
    ``previs_summary.json`` is written to the episode directory.

    Args:
        episode: Episode number.
        shot_ids: Shots to generate (None = all). Ints select by shot number;
            strings such as "2A" select one suffixed shot.
        project: Project name; DEFAULT_PROJECT is used when None.
        dry_run: Preview without API calls, file writes or store updates.

    Returns:
        One dict per processed shot. Every dict has "shot_id", "status" and
        "cost". Status is "dry_run", "ok", "qc_failed" or "failed"; saved
        frames also carry "path", "gate_1_text", "gate_1_result", "gate_2"
        and "authored_prompt", failures carry "error".

    Raises:
        FileNotFoundError: The episode plan file does not exist.
        ValueError: The plan contains no shots.
        WorkspaceStateCorruptError: Propagated from _load_bible() when a
            bible file exists but cannot be read or parsed.
    """
    # Resolve the project once so the plan, bible, output directory and
    # execution store all refer to the same project.
    project_name = project or DEFAULT_PROJECT
    paths = ProjectPaths.for_project(project_name)
    ep_code = f"EP{episode:03d}"

    # Load plan from project-specific path
    plan_path = paths.plans_dir / f"ep_{episode:03d}_plan.json"
    if not plan_path.exists():
        raise FileNotFoundError(
            f"Episode plan not found: {plan_path}. "
            "Run extraction pipeline first."
        )

    plan = json.loads(plan_path.read_text(encoding="utf-8"))
    shots = plan.get("shots", [])

    if not shots:
        raise ValueError(f"No shots in plan for {ep_code}")

    # Load bible for location enrichment (fail loud on a corrupt-but-present
    # bible; None when absent — REC-232)
    bible = _load_bible(project_name)

    # Filter to requested shot_ids (supports ints like 2 and strings like "2A")
    if shot_ids is not None:
        match_ids = set()
        match_nums = set()
        for sid in shot_ids:
            if isinstance(sid, str) and sid.isdigit():
                # A purely numeric string is a shot number, not "digits+suffix".
                sid = int(sid)
            if isinstance(sid, str):
                # Letter suffix like "2A" → match "EP001_SH02A"
                match_ids.add(f"{ep_code}_SH{int(sid[:-1]):02d}{sid[-1]}")
            else:
                match_nums.add(sid)
                match_ids.add(f"{ep_code}_SH{sid:02d}")
        # Ints are also compared against the numeric part of each shot_id, so
        # requesting 2 selects SH02 as well as any suffixed SH02x shots.
        shots = [
            s for s in shots
            if s.get("shot_id") in match_ids
            or _extract_shot_num(s.get("shot_id", "")) in match_nums
        ]

    # Prepare output directory (project-specific) — previs and keyframes
    # share the sequences/ep_NNN/ dir under v2 layout.
    ep_previs_dir = paths.episode_prep_dir(episode)
    if not dry_run:
        ep_previs_dir.mkdir(parents=True, exist_ok=True)

    logger.info(
        "Previs: EP%03d — %d shots (%s)",
        episode, len(shots), "DRY RUN" if dry_run else "LIVE",
    )

    # Open execution store for registration (best effort: generation still
    # proceeds without it, frames just are not registered).
    store = None
    if not dry_run:
        try:
            store = ExecutionStore(project=project_name)
        except Exception as e:
            logger.warning("Could not open ExecutionStore: %s — frames won't be registered", e)

    results = []
    total_cost = 0.0

    for i, shot in enumerate(shots, 1):
        shot_id = shot.get("shot_id", f"{ep_code}_SH{i:02d}")
        shot_label = _extract_shot_label(shot_id)

        # Build full context for Flash-authored previz.
        # NOTE(review): the raw ``project`` (possibly None) is forwarded
        # unchanged because build_previz_context's handling of None is not
        # visible from this file — confirm whether it should get project_name.
        context_parts = build_previz_context(
            shot=shot,
            all_shots=shots,
            bible=bible,
            episode=episode,
            project=project,
        )
        # Parts are (data, kind, text) tuples; data is None for text parts.
        n_images = sum(1 for data, _, _ in context_parts if data is not None)

        if dry_run:
            # Show the system instruction (last text part) for preview
            last_text = next(
                (text for data, kind, text in reversed(context_parts) if data is None),
                "no instruction",
            )
            logger.info("  [%d/%d] %s — %d refs, instruction: %s",
                        i, len(shots), shot_id, n_images, last_text[:80])
            results.append({
                "shot_id": shot_id,
                "status": "dry_run",
                "context_parts": len(context_parts),
                "image_refs": n_images,
                "cost": PREVIS_COST,
            })
            total_cost += PREVIS_COST
            continue

        logger.info("  [%d/%d] %s (%d context parts, %d images)",
                    i, len(shots), shot_id, len(context_parts), n_images)

        # Always write to a new take file — never overwrite previous takes
        output_path = _next_take_path(ep_previs_dir, shot_label)

        gen_result = _generate_flash_frame(context_parts=context_parts)

        if not gen_result["success"]:
            results.append({
                "shot_id": shot_id,
                "status": "failed",
                "error": gen_result.get("error", "Unknown error"),
                "cost": PREVIS_COST,
            })
            total_cost += PREVIS_COST
            continue

        # Log the authored prompt and run the Gate 1 text critique on it
        authored = gen_result.get("authored_prompt", "")
        if authored:
            logger.info("  Flash authored prompt: %s", authored[:120])
            gate_1_text = _run_gate_1_text_for_previs(authored, shot)
        else:
            gate_1_text = {"passed": True, "reason": "No authored prompt", "cost": 0.0}

        shot_cost = PREVIS_COST
        gate_cost = gate_1_text.get("cost", 0.0) or 0.0

        if not gate_1_text.get("passed", True):
            # Gate 1 failed — regenerate once from the same context. The
            # failure details live under "details"; there are no top-level
            # "extras"/"missing" keys to log.
            logger.warning("  Gate 1 FAIL for %s: details=%s — retrying",
                           shot_id, gate_1_text.get("details"))
            gen_result = _generate_flash_frame(context_parts=context_parts)
            shot_cost += PREVIS_COST  # Second generation attempt
            if gen_result["success"]:
                authored = gen_result.get("authored_prompt", "")
                # Re-run Gate 1 on retry (but don't retry again)
                if authored:
                    gate_1_text = _run_gate_1_text_for_previs(authored, shot)
                else:
                    gate_1_text = {"passed": True, "cost": 0.0}
                gate_cost += gate_1_text.get("cost", 0.0) or 0.0

        # Account for what this shot actually cost (retry and gate included)
        # so the episode total agrees with the per-shot "cost" entries.
        frame_cost = shot_cost + gate_cost
        total_cost += frame_cost

        if not gen_result["success"]:
            results.append({
                "shot_id": shot_id,
                "status": "failed",
                "error": gen_result.get("error", "Unknown error after Gate 1 retry"),
                "cost": frame_cost,
                "gate_1_text": gate_1_text,
            })
            continue

        # Save frame
        output_path.write_bytes(gen_result["image_data"])

        # Run Gate 1 mechanical QC
        try:
            from recoil.pipeline._lib.validation import Validator
            _validator = Validator()
            _g1m = _validator.run_gate_1_image(output_path)
            gate_1_mechanical = {"passed": _g1m.passed, "details": _g1m.details, "cost": _g1m.cost}
        except Exception as _e:
            # Fail-open as before, but leave a trace instead of swallowing it.
            logger.warning("  Gate 1 mechanical QC unavailable for %s: %s", shot_id, _e)
            gate_1_mechanical = {"passed": True, "reason": f"Gate 1 unavailable: {_e}", "cost": 0.0}

        # Gate 2 VQA now runs via pipeline gates, not inline
        gate_2 = {"passed": True, "cost": 0.0}

        # Determine overall status: either Gate 1 failing marks the frame
        gates_passed = (
            gate_1_text.get("passed", True)
            and gate_1_mechanical.get("passed", True)
        )
        overall_status = "ok" if gates_passed else "qc_failed"

        results.append({
            "shot_id": shot_id,
            "status": overall_status,
            "path": str(output_path),
            "cost": frame_cost,
            "gate_1_text": gate_1_text,
            "gate_1_result": gate_1_mechanical,
            "gate_2": gate_2,
            "authored_prompt": authored,
        })

        # Register in execution store
        if store:
            # Store the path relative to the project root when possible.
            try:
                rel_path = str(output_path.relative_to(paths.project_root))
            except ValueError:
                rel_path = str(output_path)
            take_entry = {
                "take_id": f"{shot_id}_T{int(time.time()) % 100000:05d}",
                "file_path": rel_path,
                "cost": frame_cost,
                "timestamp": time.time(),
            }
            store_status = "previs_generated" if overall_status == "ok" else "previs_pending"
            # Ensure valid state machine path: must go through previs_generating first
            current = (store.get_shot(shot_id) or {}).get("status", "")
            if current not in ("previs_generating",) and store_status == "previs_generated":
                store.update_shot(shot_id, episode_id=ep_code,
                                  pipeline="still", model=PREVIS_MODEL,
                                  status="previs_generating")
            store.update_shot(
                shot_id,
                episode_id=ep_code,
                pipeline="still",
                model=PREVIS_MODEL,
                status=store_status,
                append_take=take_entry,
                cost_incurred=frame_cost,
            )

    # Summary
    ok_count = sum(1 for r in results if r.get("status") == "ok")
    failed_count = sum(1 for r in results if r.get("status") in ("failed", "qc_failed"))

    logger.info(
        "Previs complete: %d/%d ok, %d failed, cost: $%.2f",
        ok_count, len(results), failed_count, total_cost,
    )

    # Save previs summary
    if not dry_run:
        summary_path = ep_previs_dir / "previs_summary.json"
        summary = {
            "episode": episode,
            "total_shots": len(results),
            "ok": ok_count,
            "failed": failed_count,
            "total_cost": round(total_cost, 4),
            "results": results,
        }
        summary_path.write_text(
            json.dumps(summary, indent=2, default=str), encoding="utf-8"
        )

    return results


def _generate_flash_frame(
    prompt: str | None = None,
    ref_images: list | None = None,
    context_parts: list | None = None,
) -> dict:
    """Generate a single previs frame via Flash 3.1.

    Supports two modes:
      1. Full context (new): context_parts from build_previz_context().
         Each part is (bytes|None, mime_or_kind, text_or_label).
         Flash writes its own prompt and generates in one call.
      2. Legacy: prompt string + optional ref_images (bytes, mime, label).

    Returns:
        {"success": True, "image_data": bytes, "authored_prompt": str}
        or {"success": False, "error": str}
    """
    try:
        from google import genai
        from google.genai import types as genai_types

        api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            return {"success": False, "error": "GEMINI_API_KEY not set"}

        client = genai.Client(api_key=api_key)

        from recoil.core.paths import get_config as _get_cfg
        _previz_temp = _get_cfg().get("previz_temperature", 0.4)
        config = genai_types.GenerateContentConfig(
            temperature=_previz_temp,
            responseModalities=["IMAGE", "TEXT"],
            imageConfig=genai_types.ImageConfig(
                aspectRatio="9:16",
            ),
        )

        # Build multimodal content
        if context_parts:
            # Full context mode: assemble from structured parts
            contents = []
            for data, kind, text_or_label in context_parts:
                if data is None:
                    # Text part
                    contents.append(genai_types.Part.from_text(text=text_or_label))
                else:
                    # Image part: data=bytes, kind=mime_type, text_or_label=label
                    contents.append(genai_types.Part.from_bytes(data=data, mime_type=kind))
                    if text_or_label:
                        contents.append(genai_types.Part.from_text(text=text_or_label))
        elif ref_images:
            # Legacy mode: ref images + prompt
            contents = []
            for img_bytes, mime_type, label in ref_images:
                contents.append(genai_types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
                contents.append(genai_types.Part.from_text(text=label))
            contents.append(genai_types.Part.from_text(text=prompt))
        else:
            contents = prompt

        response = client.models.generate_content(
            model=PREVIS_MODEL,
            contents=contents,
            config=config,
        )

        # Extract BOTH text (authored prompt) and image from response
        image_data = None
        authored_prompt = ""

        if response and response.candidates:
            for candidate in response.candidates:
                if candidate.content and candidate.content.parts:
                    for part in candidate.content.parts:
                        if hasattr(part, "inline_data") and part.inline_data:
                            image_data = part.inline_data.data
                        elif hasattr(part, "text") and part.text:
                            authored_prompt += part.text

        if image_data:
            return {
                "success": True,
                "image_data": image_data,
                "authored_prompt": authored_prompt.strip(),
            }

        return {"success": False, "error": "No image in response"}

    except Exception as e:
        return {"success": False, "error": str(e)}


def _generate_nbp_frame(
    prompt: str,
    ref_images: list[tuple[bytes, str, str]] | None = None,
) -> dict:
    """Generate a production keyframe via NBP (Gemini 3 Pro).

    Same API pattern as _generate_flash_frame() but optimized for production:
    - Model: gemini-3-pro-image-preview (NBP)
    - Temperature: 0.3 (more deterministic for production quality)
    - Response: IMAGE only (no text authoring — prompt is pre-built)
    - Max 5 reference images (NBP embedding averaging dilutes faces beyond 5)
    - Retry with exponential backoff on ResourceExhausted/InternalServerError

    Args:
        prompt: The NBP-optimized prompt (from build_smart_prompt or director).
        ref_images: Optional list of (bytes, mime_type, label) reference images.
                    Ordered by recency bias (least important first, hero last).

    Returns:
        {"success": True, "image_data": bytes}
        or {"success": False, "error": str}
    """
    try:
        from tenacity import (
            retry,
            stop_after_attempt,
            wait_exponential,
            retry_if_exception_type,
        )
    except ImportError:
        # Fallback if tenacity not installed — no retry
        return _generate_nbp_frame_inner(prompt, ref_images)

    from google.api_core.exceptions import ResourceExhausted, InternalServerError

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=30),
        retry=retry_if_exception_type((ResourceExhausted, InternalServerError)),
        reraise=True,
    )
    def _call_with_retry():
        return _generate_nbp_frame_inner(prompt, ref_images, raise_on_quota=True)

    try:
        return _call_with_retry()
    except (ResourceExhausted, InternalServerError) as e:
        return {"success": False, "error": f"NBP exhausted after 5 retries: {e}"}
    except Exception as e:
        return {"success": False, "error": str(e)}


def _generate_nbp_frame_inner(
    prompt: str,
    ref_images: list[tuple[bytes, str, str]] | None = None,
    raise_on_quota: bool = False,
) -> dict:
    """Inner NBP generation — called directly or via tenacity retry."""
    try:
        from google import genai
        from google.genai import types as genai_types

        api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            return {"success": False, "error": "GEMINI_API_KEY not set"}

        client = genai.Client(api_key=api_key)

        from recoil.core.paths import get_config as _get_cfg
        _previz_temp2 = _get_cfg().get("previz_temperature", 0.4)
        config = genai_types.GenerateContentConfig(
            temperature=_previz_temp2,
            responseModalities=["IMAGE"],
            imageConfig=genai_types.ImageConfig(
                aspectRatio="9:16",
            ),
        )

        # Build multimodal content: refs first (recency bias), then prompt
        contents = []
        if ref_images:
            # Cap at 5 refs (NBP embedding averaging dilutes faces beyond 5)
            for img_bytes, mime_type, label in ref_images[:5]:
                contents.append(genai_types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
                if label:
                    contents.append(genai_types.Part.from_text(text=label))

        contents.append(genai_types.Part.from_text(text=prompt))

        response = client.models.generate_content(
            model=NBP_MODEL,
            contents=contents,
            config=config,
        )

        # Extract image from response
        image_data = None
        if response and response.candidates:
            for candidate in response.candidates:
                if candidate.content and candidate.content.parts:
                    for part in candidate.content.parts:
                        if hasattr(part, "inline_data") and part.inline_data:
                            image_data = part.inline_data.data

        if image_data:
            return {"success": True, "image_data": image_data}

        return {"success": False, "error": "No image in NBP response"}

    except Exception as e:
        if raise_on_quota:
            # Let tenacity handle ResourceExhausted/InternalServerError
            from google.api_core.exceptions import ResourceExhausted, InternalServerError
            if isinstance(e, (ResourceExhausted, InternalServerError)):
                raise
        return {"success": False, "error": str(e)}


def _next_take_path(ep_dir: Path, shot_label: str) -> Path:
    """Return the next available take path without overwriting anything.

    Scans for shot_{label}.png and shot_{label}_take*.png, returns
    the next numbered take path.
    """
    import re
    base = ep_dir / f"shot_{shot_label}.png"
    existing_takes = sorted(ep_dir.glob(f"shot_{shot_label}_take*.png"))
    if not base.exists() and not existing_takes:
        # First ever take — use base name
        return base
    # Find highest existing take number
    highest = 0
    for t in existing_takes:
        m = re.search(r"_take(\d+)", t.stem)
        if m:
            highest = max(highest, int(m.group(1)))
    # If base exists but no takes yet, base is implicitly take 0
    if base.exists() and highest == 0:
        highest = 0
    next_take = highest + 1
    return ep_dir / f"shot_{shot_label}_take{next_take:03d}.png"


def _archive_as_take(output_path: Path, shot_label: str, ep_dir: Path) -> None:
    """Rename an existing previz frame to the next numbered take file.

    shot_003.png → shot_003_take001.png (or _take002 if 001 exists, etc.)
    The take number is one above the highest ``shot_{label}_take*.png``
    already in ``ep_dir``. Kept for backward compatibility but no longer
    called by generate_previs.
    """
    import re

    # Collect the numbers of all takes already on disk for this shot.
    numbers = []
    for take in ep_dir.glob(f"shot_{shot_label}_take*.png"):
        found = re.search(r"_take(\d+)", take.stem)
        if found:
            numbers.append(int(found.group(1)))

    next_take = max(numbers, default=0) + 1
    take_path = ep_dir / f"shot_{shot_label}_take{next_take:03d}.png"
    output_path.rename(take_path)
    logger.info("  Archived previous → %s", take_path.name)


def _extract_shot_num(shot_id: str) -> int:
    """Extract numeric shot index from shot_id like 'EP001_SH03'.

    For integer comparisons (filtering). Use _extract_shot_label() for filenames.
    """
    import re
    match = re.search(r"SH(\d+)", shot_id)
    if match:
        return int(match.group(1))
    return 0


def _extract_shot_label(shot_id: str) -> str:
    """Extract shot label including letter suffix for filenames.

    'EP001_SH02' → '002', 'EP001_SH02A' → '002a'
    """
    import re
    match = re.search(r"SH(\d+)([A-Za-z]?)", shot_id)
    if match:
        num = int(match.group(1))
        suffix = match.group(2).lower()
        return f"{num:03d}{suffix}"
    return "000"


# ── CLI ──────────────────────────────────────────────────────────────

def main():
    import argparse

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%H:%M:%S",
    )

    parser = argparse.ArgumentParser(description="Starsend Previs Generator")
    parser.add_argument("--episode", "-e", type=int, required=True, help="Episode number")
    parser.add_argument("--shots", "-s", type=str, default=None,
                        help="Shot IDs (e.g. '1-5', '1,3,7', '2A', '2,2A,3,3A')")
    parser.add_argument("--project", "-p", type=str, default=None, help="Project name")
    parser.add_argument("--dry-run", "-d", action="store_true", help="Preview without API calls")

    args = parser.parse_args()

    # Parse shot IDs — supports integers, ranges, and letter suffixes (e.g. 2A, 3A)
    shot_ids = None
    if args.shots:
        shot_ids = []
        for part in args.shots.split(","):
            part = part.strip()
            if "-" in part and not any(c.isalpha() for c in part):
                # Pure numeric range: "1-5"
                start, end = part.split("-", 1)
                shot_ids.extend(range(int(start), int(end) + 1))
            elif part[-1:].isalpha() and part[:-1].isdigit():
                # Letter suffix: "2A" → store as string
                shot_ids.append(part.upper())
            else:
                shot_ids.append(int(part))

    results = generate_previs(
        episode=args.episode,
        shot_ids=shot_ids,
        project=args.project,
        dry_run=args.dry_run,
    )

    print(json.dumps(results, indent=2, default=str))


if __name__ == "__main__":
    main()
