#!/usr/bin/env python3
"""
prep_location_refs.py — Generate location refs from GlobalBible.

Two modes:
  Legacy (default): 3-angle reverse-angle chaining (wide → medium → closeup).
  Shot-aware (--shot-aware): 2-bin focal-length binning from shot plans.
    - Wide (<35mm): text-only establishing shot
    - Mid (35-85mm): radial-chained from wide, driven by first shot in bin
    - ECU (>85mm): no location ref (location_view_id = null)

Shot-aware mode writes location_view_id back to shot plan JSON (static binding).

Usage:
    # Legacy mode
    python -m tools.prep_location_refs --project tartarus
    python -m tools.prep_location_refs --project tartarus --wide-only

    # Shot-aware mode
    python -m tools.prep_location_refs --project starsend-test --shot-aware
    python -m tools.prep_location_refs --project starsend-test --shot-aware --episode EP001
    python -m tools.prep_location_refs --project starsend-test --shot-aware --dry-run
"""

import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Optional
from datetime import datetime

sys.path.insert(0, str(Path(__file__).parent.parent))

from recoil.core.paths import ProjectPaths
from recoil.core.model_profiles import get_model
# Image model id used for every generation call in this script, resolved via
# recoil's model profiles for ("exploration", "image").
FLASH_MODEL = get_model("exploration", "image")
# Flat per-image cost estimate (reported with a "$" prefix in the run totals);
# attached as "cost" to every successfully generated result.
FLASH_COST = 0.039

# Legacy-mode angles per location, as (angle_name, prompt direction) pairs.
# Order matters: generate_location_refs walks this list in order and, once an
# image exists, chains each following angle from the previous one. angle_name
# is also the output filename suffix ("{slug}_{angle_name}.png"), and
# --wide-only uses ANGLES[0] alone.
ANGLES = [
    ("wide", "Wide cinematic establishing shot, full environment visible."),
    ("medium", "Reverse angle medium shot of the same location, different vantage point, "
               "showing depth and architectural detail from the opposite direction."),
    ("closeup", "Close-up detail shot at a junction or focal point of the same location, "
                "showing texture, material quality, and atmospheric detail."),
]

# ── Shot-aware prompt templates (ADR-L03) ────────────────────────────

# NOTE(review): SYSTEM_INSTRUCTION is not referenced anywhere else in this
# file — the Flash generation helpers never pass a system instruction, so this
# text currently has no effect. Confirm whether it should be wired in.
# NOTE(review): "9:16" is hard-coded here and in both templates, while the
# aspect ratio actually requested from the API comes from the
# "production_aspect_ratio" config value.
SYSTEM_INSTRUCTION = (
    "You are a cinematic location scout creating empty stage plates for a film production. "
    "Do not include any people, characters, human figures, or narrative props. "
    "Generate only the empty architectural environment. Maintain a strict 9:16 vertical aspect ratio."
)

# Prompt for the wide bin, filled with str.format() by generate_shot_aware_refs.
# Placeholders: bible_name, bible_desc, bible_lighting, bible_mood,
# aesthetic_tone, lens (focal length as bare digits, "mm" is appended here),
# camera_position.
WIDE_TEMPLATE = """Cinematic empty stage plate. Strictly architectural, no characters or narrative props.

LOCATION: {bible_name}
DESCRIPTION: {bible_desc}
ATMOSPHERE: {bible_lighting}. {bible_mood}. {aesthetic_tone}

CAMERA SETUP:
- Shot Type: Wide Establishing Shot
- Lens: {lens}mm
- Position: {camera_position}

COMPOSITION:
Deep depth of field. The entire architectural space is visible. Emphasize the verticality of the 9:16 frame. The immediate foreground consists of an unobstructed bare floor with clear leading lines, leaving intentional negative space in the bottom half of the frame. No clutter or objects blocking the camera lens."""

# Prompt for the mid bin; same placeholders as WIDE_TEMPLATE. In shot-aware
# mode it is sent together with that location's wide image as a reference
# when a wide image was generated in the same run.
MID_TEMPLATE = """Cinematic empty stage plate. Strictly architectural, no characters or narrative props.

LOCATION: {bible_name}
DESCRIPTION: {bible_desc}
ATMOSPHERE: {bible_lighting}. {bible_mood}. {aesthetic_tone}

CAMERA SETUP:
- Shot Type: Medium / Mid-Corridor Shot
- Lens: {lens}mm
- Position: {camera_position}

COMPOSITION:
Medium depth of field focusing on the midground architectural details. The immediate foreground is an unobstructed bare floor, providing a clear line of sight to the midground. Do not place objects, crates, or debris in the lower third of the frame. The space feels lived-in but is currently completely empty of people."""

# Focal length bin boundaries: inclusive upper bounds in whole millimetres,
# compared with "<=" against _parse_focal_length() results.
_BIN_WIDE_MAX = 34   # <35mm = wide
_BIN_MID_MAX = 85    # 35-85mm = mid
# >85mm = ECU → no location ref

# Module logger; main() configures output via logging.basicConfig.
logger = logging.getLogger("starsend.prep_location_refs")


def _parse_focal_length(shot: dict, default: int = 50) -> int:
    """Safely extract focal length as int from a shot dict."""
    pd = shot.get("prompt_data") or {}
    fl_str = pd.get("focal_length", f"{default}mm")
    try:
        return int(str(fl_str).replace("mm", ""))
    except (ValueError, TypeError):
        return default


def generate_location_refs(
    bible: dict,
    project: Optional[str] = None,
    dry_run: bool = False,
    wide_only: bool = False,
) -> list:
    """Generate 3-angle Flash 3.1 refs per location (ADR-L01/L02).

    Args:
        bible: Global bible dict (from state/visual/global_bible.json).
        project: Project name (for logging).
        dry_run: If True, log prompts but don't call API.
        wide_only: If True, only generate the wide establishing shot (legacy mode).

    Returns:
        List of dicts: [{location_id, angle, path, cost}]
    """
    from recoil.core.paths import get_config
    _cfg = get_config()
    aspect_ratio = _cfg.get("production_aspect_ratio", "9:16")

    locations = bible.get("locations", {})
    if not locations:
        logger.warning("No locations in bible")
        return []

    angles_to_generate = [ANGLES[0]] if wide_only else ANGLES
    results = []

    for loc_id, loc_data in locations.items():
        slug = _slugify(loc_id)
        class_folder = ProjectPaths.for_project(project).asset_subject_dir("loc", slug)
        description = loc_data.get("description", loc_id)
        lighting = loc_data.get("lighting_profile", "")
        palette = loc_data.get("color_palette", [])
        atmosphere = loc_data.get("atmosphere", "")

        prev_image_bytes = None  # For reverse-angle chaining

        for angle_name, angle_direction in angles_to_generate:
            ref_path = class_folder / f"{slug}_{angle_name}.png"

            # Skip if this angle already exists
            if ref_path.exists():
                logger.info(f"  Skipping {loc_id}/{angle_name} (exists)")
                # Load existing for chaining to next angle
                if not dry_run:
                    prev_image_bytes = ref_path.read_bytes()
                continue

            # Also check legacy moodboard path for wide angle
            if angle_name == "wide":
                legacy_path = class_folder / f"{slug}_moodboard.png"
                if legacy_path.exists():
                    logger.info(f"  Skipping {loc_id}/{angle_name} (legacy moodboard exists)")
                    if not dry_run:
                        prev_image_bytes = legacy_path.read_bytes()
                    continue

            # Build prompt
            prompt_parts = [
                f"Cinematic environment concept art, {aspect_ratio} {'vertical' if aspect_ratio == '9:16' else 'horizontal'} frame.",
                f"Location: {description}.",
                angle_direction,
            ]

            if atmosphere:
                prompt_parts.append(f"Atmosphere: {atmosphere}.")
            if lighting:
                prompt_parts.append(f"Lighting: {lighting}.")
            if palette:
                palette_str = ", ".join(palette[:5])
                prompt_parts.append(f"Color palette: {palette_str}.")

            prompt_parts.append(
                "No people, no characters, no figures. "
                "Environment only. Photorealistic, high production value."
            )
            prompt = " ".join(prompt_parts)

            if dry_run:
                chained = "chained" if prev_image_bytes else "text-only"
                logger.info(f"  [DRY RUN] {loc_id}/{angle_name} ({chained}): {prompt[:120]}...")
                results.append({"location_id": loc_id, "angle": angle_name, "path": None, "cost": 0})
                continue

            # Generate — use chaining if we have a previous angle
            if prev_image_bytes is not None:
                image_data = _generate_flash_image_with_ref(prompt, prev_image_bytes, aspect_ratio=aspect_ratio)
            else:
                image_data = _generate_flash_image(prompt, aspect_ratio=aspect_ratio)

            if image_data is None:
                logger.warning(f"  Failed: {loc_id}/{angle_name}")
                continue

            # Save
            ref_path.parent.mkdir(parents=True, exist_ok=True)
            ref_path.write_bytes(image_data)
            logger.info(f"  Saved: {ref_path}")

            results.append({
                "location_id": loc_id,
                "angle": angle_name,
                "path": str(ref_path),
                "cost": FLASH_COST,
            })

            prev_image_bytes = image_data
            time.sleep(2)  # Rate limit

    total_cost = sum(r["cost"] for r in results)
    logger.info(f"Location refs complete: {len(results)} generated, ${total_cost:.3f}")
    return results


def generate_shot_aware_refs(
    bible: dict,
    shot_plans: dict,
    project: Optional[str] = None,
    dry_run: bool = False,
) -> list:
    """Generate shot-aware location refs using focal-length binning.

    Reads shot plans to determine what camera setups actually exist for each
    location. The shots are binned into wide (<35mm) and mid (35-85mm), and
    one ref is generated per non-empty bin. ECU shots (>85mm) get no location
    ref.

    The first shot in each bin drives that bin's lens and camera position.
    The mid ref is chained from the wide ref generated for the same location
    in this run, when there is one.

    Args:
        bible: Global bible dict.
        shot_plans: Dict of {episode_id: plan_json} with shot arrays.
        project: Project name, used to resolve the asset directories.
        dry_run: If True, log prompts but don't call the API or write files.

    Returns:
        List of dicts with keys: location_id, focal_bin, path, driving_shot_id,
        shots_covered, cost. Failed generations are omitted.
    """
    from recoil.core.paths import get_config
    aspect_ratio = get_config().get("production_aspect_ratio", "9:16")
    aesthetic_directives = bible.get("aesthetic_directives") or {}
    aesthetic_tone = aesthetic_directives.get("tone", "")

    locations = bible.get("locations", {})
    if not locations:
        logger.warning("No locations in bible")
        return []

    # ── Step 1: Collect all shots per location ───────────────────────
    location_shots = {}  # {location_id: [shot_dict, ...]}
    for plan in shot_plans.values():
        for shot in plan.get("shots") or []:
            # "asset_data" may be missing or explicitly null in a plan entry.
            loc_id = (shot.get("asset_data") or {}).get("location_id")
            if loc_id and loc_id in locations:
                location_shots.setdefault(loc_id, []).append(shot)

    project_paths = ProjectPaths.for_project(project)  # loop-invariant
    results = []
    timestamp = datetime.now().strftime("%y%m%d_%H%M")

    for loc_id, loc_data in locations.items():
        shots = location_shots.get(loc_id, [])
        if not shots:
            logger.info(f"  {loc_id}: no shots in plan, skipping")
            continue

        slug = _slugify(loc_id)
        class_folder = project_paths.asset_subject_dir("loc", slug)
        description = loc_data.get("description", loc_id)
        lighting = loc_data.get("lighting", loc_data.get("lighting_profile", ""))
        mood = loc_data.get("mood", loc_data.get("atmosphere", ""))

        # ── Step 2: Bin shots by focal length ────────────────────────
        bins = {"wide": [], "mid": []}
        for shot in shots:
            fl = _parse_focal_length(shot)
            if fl <= _BIN_WIDE_MAX:
                bins["wide"].append(shot)
            elif fl <= _BIN_MID_MAX:
                bins["mid"].append(shot)
            # fl > _BIN_MID_MAX → ECU, no location ref

        ecu_count = len(shots) - len(bins["wide"]) - len(bins["mid"])
        logger.info(f"  {loc_id}: wide={len(bins['wide'])} mid={len(bins['mid'])} "
                    f"ecu={ecu_count}")

        # ── Step 3: Generate views (radial chaining from wide) ───────
        wide_bytes = None      # real run: wide image used as mid's reference
        wide_planned = False   # dry-run: whether a wide ref would exist (no bytes in dry-run)

        for bin_name in ("wide", "mid"):
            bin_shots = bins[bin_name]
            if not bin_shots:
                continue

            # First-in-bin wins: it drives the lens and camera position.
            driving_shot = bin_shots[0]
            driving_id = driving_shot.get("shot_id", "unknown")
            lens = str(_parse_focal_length(driving_shot))

            # Derive camera position from spatial data ("spatial_data" and its
            # fields may be missing or null).
            spatial = driving_shot.get("spatial_data") or {}
            camera_side = spatial.get("camera_side") or "A"
            screen_dir = spatial.get("screen_direction") or "center"
            camera_position = f"Static, eye-level, camera side {camera_side}, looking {screen_dir}"

            template = WIDE_TEMPLATE if bin_name == "wide" else MID_TEMPLATE
            prompt = template.format(
                bible_name=loc_data.get("name", loc_id),
                bible_desc=description,
                bible_lighting=lighting,
                bible_mood=mood,
                aesthetic_tone=aesthetic_tone,
                lens=lens,
                camera_position=camera_position,
            )

            ref_path = class_folder / f"{slug}_{bin_name}_v{timestamp}.png"
            covered_ids = [s.get("shot_id", "") for s in bin_shots]

            if dry_run:
                chained = "chained from wide" if bin_name == "mid" and wide_planned else "text-only"
                logger.info(f"  [DRY RUN] {loc_id}/{bin_name} ({chained}, driving={driving_id})")
                logger.info(f"    Covers: {covered_ids}")
                logger.info(f"    Prompt: {prompt[:120]}...")
                results.append({
                    "location_id": loc_id,
                    "focal_bin": bin_name,
                    "path": None,
                    "driving_shot_id": driving_id,
                    "shots_covered": covered_ids,
                    "cost": 0,
                })
                if bin_name == "wide":
                    wide_planned = True
                continue

            if bin_name == "mid" and wide_bytes is not None:
                # Radial chaining: mid chains from wide.
                image_data = _generate_flash_image_with_ref(
                    prompt, wide_bytes, aspect_ratio=aspect_ratio
                )
            else:
                # Wide is always text-only (as is mid when no wide exists).
                image_data = _generate_flash_image(prompt, aspect_ratio=aspect_ratio)

            if image_data is None:
                logger.warning(f"  Failed: {loc_id}/{bin_name}")
                continue

            # Create the folder only when there is an image to write, so
            # dry-runs and shotless locations leave the disk untouched.
            ref_path.parent.mkdir(parents=True, exist_ok=True)
            ref_path.write_bytes(image_data)
            logger.info(f"  Saved: {ref_path}")

            results.append({
                "location_id": loc_id,
                "focal_bin": bin_name,
                "path": str(ref_path),
                "driving_shot_id": driving_id,
                "shots_covered": covered_ids,
                "cost": FLASH_COST,
            })

            if bin_name == "wide":
                wide_bytes = image_data

            time.sleep(2)  # Rate limit

    total_cost = sum(r["cost"] for r in results)
    logger.info(f"Shot-aware refs complete: {len(results)} generated, ${total_cost:.3f}")
    return results


def bind_location_views(shot_plans: dict, results: list) -> int:
    """Write location_view_id back to shot plan JSON (static binding).

    For each generated view, sets ``location_view_id`` on every covered shot
    to the view's filename (or None when the result has no path). Shots that
    are not covered and whose focal length is above the mid bin (ECU, >85mm)
    are explicitly set to None. Other uncovered shots are left untouched.

    Args:
        shot_plans: Dict of {episode_id: plan_json} — mutated in place.
        results: Output from generate_shot_aware_refs().

    Returns:
        Number of shots updated.
    """
    # Lookup: shot_id → ref filename.
    shot_to_view = {}
    for result in results:
        path = result.get("path")
        filename = Path(path).name if path else None
        for shot_id in result.get("shots_covered") or []:
            # "" is the placeholder recorded for shots without an id. Binding
            # it would attach this view to every id-less shot in every episode.
            if shot_id:
                shot_to_view[shot_id] = filename

    updated = 0
    for plan in shot_plans.values():
        for shot in plan.get("shots") or []:
            shot_id = shot.get("shot_id", "")
            if shot_id in shot_to_view:
                shot["location_view_id"] = shot_to_view[shot_id]
                updated += 1
            elif _parse_focal_length(shot) > _BIN_MID_MAX:
                # ECU — explicitly null (no location ref by design).
                shot["location_view_id"] = None
                updated += 1

    return updated


def _generate_flash_image(prompt: str, aspect_ratio: str = None) -> Optional[bytes]:
    """Generate one image via Flash 3.1 (text-only)."""
    try:
        from google import genai
        from google.genai import types
    except ImportError:
        logger.error("google-genai SDK not installed")
        return None

    if not aspect_ratio:
        from recoil.core.paths import get_config
        aspect_ratio = get_config().get("production_aspect_ratio", "9:16")

    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        logger.error("GEMINI_API_KEY not set")
        return None

    client = genai.Client(api_key=api_key)

    try:
        response = client.models.generate_content(
            model=FLASH_MODEL,
            contents=[types.Part(text=prompt)],
            config=types.GenerateContentConfig(
                response_modalities=["IMAGE", "TEXT"],
                image_config=types.ImageConfig(
                    aspect_ratio=aspect_ratio,
                ),
            ),
        )
    except Exception as e:
        logger.error(f"Flash API error: {e}")
        return None

    return _extract_image_bytes(response)


def _generate_flash_image_with_ref(prompt: str, ref_bytes: bytes, aspect_ratio: str = None) -> Optional[bytes]:
    """Generate image with a scene reference for angle chaining (ADR-L02)."""
    try:
        from google import genai
        from google.genai import types
    except ImportError:
        logger.error("google-genai SDK not installed")
        return None

    if not aspect_ratio:
        from recoil.core.paths import get_config
        aspect_ratio = get_config().get("production_aspect_ratio", "9:16")

    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        logger.error("GEMINI_API_KEY not set")
        return None

    client = genai.Client(api_key=api_key)

    parts = [
        types.Part.from_bytes(data=ref_bytes, mime_type="image/png"),
        types.Part(text="[SCENE ENVIRONMENT REFERENCE]"),
        types.Part(text=prompt),
    ]

    try:
        response = client.models.generate_content(
            model=FLASH_MODEL,
            contents=parts,
            config=types.GenerateContentConfig(
                response_modalities=["IMAGE", "TEXT"],
                image_config=types.ImageConfig(
                    aspect_ratio=aspect_ratio,
                ),
            ),
        )
    except Exception as e:
        logger.error(f"Flash API error (with ref): {e}")
        return None

    return _extract_image_bytes(response)


def _extract_image_bytes(response) -> Optional[bytes]:
    """Extract image bytes from a genai response."""
    if response and response.candidates:
        for candidate in response.candidates:
            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if hasattr(part, "inline_data") and part.inline_data:
                        return part.inline_data.data
    return None


from recoil.pipeline._lib.taxonomy import slugify_asset_id as _slugify


def main():
    parser = argparse.ArgumentParser(
        description="Generate location refs from GlobalBible (legacy 3-view or shot-aware 2-view)"
    )
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--project", default=None)
    parser.add_argument("--wide-only", action="store_true",
                        help="Only generate wide establishing shot (legacy mode)")
    parser.add_argument("--shot-aware", action="store_true",
                        help="Use shot-plan-aware binning (reads plans from state dir)")
    parser.add_argument("--episode", type=str, default=None,
                        help="Episode ID to process (e.g. EP001). Default: all episodes.")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", datefmt="%H:%M:%S")

    # Load bible — project-scoped, no engine-default fallback (fail loud).
    if not args.project:
        sys.exit("ERROR: --project is required (no engine-default bible fallback).")
    bible_path = ProjectPaths.for_project(args.project).global_bible_path

    if not bible_path.exists():
        print(f"ERROR: Global bible not found at {bible_path}")
        print("Run Stage 1 (breakdown pass) first.")
        sys.exit(1)

    bible = json.loads(bible_path.read_text(encoding="utf-8"))

    if args.shot_aware:
        # Load shot plans
        plans_dir = bible_path.parent / "plans"
        if not plans_dir.exists():
            print(f"ERROR: Plans directory not found at {plans_dir}")
            print("Run Stage 2 (plan pass) first.")
            sys.exit(1)

        shot_plans = {}
        for plan_file in sorted(plans_dir.glob("ep_*_plan.json")):
            ep_id = plan_file.stem.replace("_plan", "").upper().replace("EP_", "EP")
            if args.episode and ep_id != args.episode.upper():
                continue
            shot_plans[ep_id] = json.loads(plan_file.read_text(encoding="utf-8"))

        if not shot_plans:
            print(f"ERROR: No shot plans found in {plans_dir}")
            sys.exit(1)

        print(f"Shot-aware mode: {len(shot_plans)} episode(s)")

        # Generate refs
        results = generate_shot_aware_refs(
            bible, shot_plans, project=args.project, dry_run=args.dry_run
        )

        # Static binding — write location_view_id back to plans
        if not args.dry_run and results:
            updated = bind_location_views(shot_plans, results)
            for ep_id, plan in shot_plans.items():
                plan_file = plans_dir / f"ep_{ep_id.lower().replace('ep', '')}_plan.json"
                plan_file.write_text(
                    json.dumps(plan, indent=2, ensure_ascii=False),
                    encoding="utf-8",
                )
            print(f"Static binding: {updated} shots updated in plan JSON")

        total_cost = sum(r["cost"] for r in results)
        print(f"\nDone: {len(results)} shot-aware refs, ${total_cost:.3f}")

    else:
        # Legacy mode
        results = generate_location_refs(
            bible, project=args.project, dry_run=args.dry_run, wide_only=args.wide_only
        )
        total_cost = sum(r["cost"] for r in results)
        angles = len(set(r.get("angle", "wide") for r in results))
        print(f"\nDone: {len(results)} refs ({angles} angle types), ${total_cost:.3f}")


# Script entry point (also used by `python -m tools.prep_location_refs`).
if __name__ == "__main__":
    main()
