#!/usr/bin/env python3
"""
generate_location_refs.py — Generate missing location reference images via NBP.

The #1 data gap: 72 locations defined in breakdown.json, 0 reference images.

Usage:
    python -m tools.generate_location_refs --episode 1 --dry-run
    python -m tools.generate_location_refs --episode 1
    python -m tools.generate_location_refs --all --dry-run
"""

import argparse
import logging
import os
import sys
import time
from pathlib import Path
from typing import Optional

sys.path.insert(0, str(Path(__file__).parent.parent))

from recoil.pipeline._lib.recoil_bridge import load_breakdown, load_project_config
from recoil.pipeline._lib.prompt_engine import build_location_ref_prompt
from recoil.core.model_profiles import get_cost, get_model

# ── Optional imports ─────────────────────────────────────────────────

try:
    from google import genai
    from google.genai import types
    _HAS_GENAI = True
except ImportError:
    genai = None
    types = None
    _HAS_GENAI = False

# ── Constants ────────────────────────────────────────────────────────

from recoil.core.paths import ProjectPaths

logger = logging.getLogger("starsend.location_refs")


def _sniff_image_mime(image_data: bytes) -> str:
    """Guess an image MIME type from its leading magic bytes.

    Recognises JPEG and WebP signatures; anything else is reported as PNG,
    which is what this module assumed unconditionally before.
    """
    if image_data[:3] == b"\xff\xd8\xff":
        return "image/jpeg"
    if image_data[:4] == b"RIFF" and image_data[8:12] == b"WEBP":
        return "image/webp"
    return "image/png"


def _answer_reports_figures(answer: str) -> bool:
    """Return True when a YES/NO answer contains the standalone word YES.

    A word-boundary match is used instead of a plain substring test so that
    incidental text such as "eyes" is not mistaken for a positive answer.
    """
    import re

    return re.search(r"\bYES\b", answer.upper()) is not None


def _check_sterility(image_data: bytes, genai_client=None, slug: str = "") -> bool:
    """Check whether a location ref image is sterile (no human figures).

    Two vision classifiers are tried in order:

    1. Gemini Flash through ``genai_client``, when a client is supplied.
    2. Claude Haiku, when the ``anthropic`` package can be imported and
       ``ANTHROPIC_API_KEY`` is set.

    The check fails open: when neither classifier yields an answer the image
    is accepted and a warning is logged.

    Args:
        image_data: Raw encoded image bytes to classify.
        genai_client: Optional google-genai client used for the Gemini check.
        slug: Location slug, used only in log messages.

    Returns:
        False if a classifier answered YES (figures present); True if the
        image was judged clean or could not be checked at all.
    """
    import base64

    sterility_prompt = "Does this image contain any human figures, bodies, silhouettes, or people? Answer only YES or NO."
    # Both APIs take an explicit media type; derive it from the bytes rather
    # than assuming PNG.
    mime_type = _sniff_image_mime(image_data)

    # Primary: Gemini Flash vision check through the caller-supplied client.
    if genai_client:
        try:
            resp = genai_client.models.generate_content(
                model="gemini-2.0-flash",
                contents=[
                    types.Part.from_bytes(data=image_data, mime_type=mime_type),
                    types.Part.from_text(text=sterility_prompt),
                ],
                config=types.GenerateContentConfig(
                    response_modalities=["TEXT"],
                    temperature=0.0,
                ),
            )
            answer = resp.text.strip() if resp.text else ""
            if _answer_reports_figures(answer):
                logger.info("  Sterility check (%s): CONTAMINATED (figures detected)", slug)
                return False
            logger.info("  Sterility check (%s): CLEAN", slug)
            return True
        except Exception as e:
            # Best-effort gate: any SDK/network failure falls through to Haiku.
            logger.warning("  Gemini sterility check failed: %s", e)

    # Fallback: Haiku (if Anthropic SDK + key available)
    try:
        import anthropic

        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not set")
        client = anthropic.Anthropic(api_key=api_key)
        b64 = base64.standard_b64encode(image_data).decode()
        resp = client.messages.create(
            model="claude-haiku-4-5-20251001",
            max_tokens=10,
            messages=[{
                "role": "user",
                "content": [
                    {"type": "image", "source": {"type": "base64", "media_type": mime_type, "data": b64}},
                    {"type": "text", "text": sterility_prompt},
                ],
            }],
        )
        answer = resp.content[0].text.strip()
        if _answer_reports_figures(answer):
            logger.info("  Sterility check (%s): CONTAMINATED (Haiku detected figures)", slug)
            return False
        logger.info("  Sterility check (%s): CLEAN", slug)
        return True
    except Exception as e:
        # Covers a missing SDK, a missing key, and API errors alike.
        logger.warning("  Haiku sterility check also failed: %s", e)

    # If both fail, pass through with warning
    logger.warning("  Sterility check (%s): SKIPPED (no vision model available)", slug)
    return True


from recoil.pipeline._lib.taxonomy import slugify_asset_id as _slugify


def _get_zone(location_data: dict) -> str:
    """Extract zone from location data (e.g. 'lower_decks', 'cryo_sector')."""
    return location_data.get("zone", "unknown")


def get_episode_locations(episode: int, project: str = None) -> list[str]:
    """Return the keys of every breakdown location whose ``episodes`` list contains *episode*.

    The breakdown is loaded via ``load_breakdown(project)``; locations with no
    ``episodes`` entry never match.
    """
    location_map = load_breakdown(project).get("locations", {})
    return [
        key
        for key, data in location_map.items()
        if episode in data.get("episodes", [])
    ]


def get_all_locations(project: str = None) -> list[str]:
    """Return every location key found under ``locations`` in the project's breakdown."""
    breakdown = load_breakdown(project)
    location_map = breakdown.get("locations", {})
    return list(location_map)


def location_has_ref(location_key: str, project: str = None) -> bool:
    """Report whether the location's asset directory already holds an image.

    The directory is the one ``ProjectPaths`` returns for the ``"loc"`` asset
    with the slugified key; only direct children with a .png/.jpg/.jpeg/.webp
    suffix (case-insensitive) count.
    """
    slug = _slugify(location_key)
    ref_dir = ProjectPaths.for_project(project).asset_subject_dir("loc", slug)
    if not ref_dir.exists():
        return False
    image_suffixes = {".png", ".jpg", ".jpeg", ".webp"}
    return any(entry.suffix.lower() in image_suffixes for entry in ref_dir.iterdir())


def _first_inline_image(response):
    """Return the data of the first non-empty inline image part in *response*, or None.

    Candidates and their parts are scanned in order and the scan stops at the
    first hit, so an image in a later candidate can never replace an earlier one.
    """
    if not response or not response.candidates:
        return None
    for candidate in response.candidates:
        content = candidate.content
        if not content or not content.parts:
            continue
        for part in content.parts:
            inline = getattr(part, "inline_data", None)
            if inline and inline.data:
                return inline.data
    return None


def _request_location_image(client, model, prompt):
    """Request one 9:16 image for *prompt* from *model* and return its data (or None).

    Exceptions raised by the SDK call propagate to the caller.
    """
    response = client.models.generate_content(
        model=model,
        contents=[types.Part(text=prompt)],
        config=types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
            image_config=types.ImageConfig(aspect_ratio="9:16"),
        ),
    )
    return _first_inline_image(response)


def generate_location_ref(
    location_key: str,
    location_data: dict,
    config: dict,
    model: str = get_model("production", "image"),
    dry_run: bool = False,
    project: str = None,
) -> Optional[Path]:
    """Generate a single location reference image.

    The image is generated from a prompt built out of the location's
    description samples and lighting notes, passed through the sterility
    gate (with one reinforced-prompt retry on rejection), and written to the
    location's asset directory as ``<slug>_ref.png``.

    Args:
        location_key: Full location key (e.g. "INT. LEVIATHAN - LOWER DECK SALVAGE CORRIDOR")
        location_data: Location dict from breakdown.json
        config: Project config dict
        model: Model ID to use
        dry_run: If True, print prompt but don't generate
        project: Project identifier forwarded to ``ProjectPaths.for_project``
            to resolve the output directory.

    Returns:
        Path to saved image, or None on dry run, missing SDK or API key,
        API error, no image in the response, or a sterility rejection that
        the retry did not resolve.
    """
    # Build description from samples, falling back to the key itself.
    desc_samples = location_data.get("description_samples", [])
    location_desc = " ".join(desc_samples) if desc_samples else location_key
    lighting_notes = location_data.get("lighting_notes", [])

    prompt = build_location_ref_prompt(location_desc, lighting_notes, config)
    slug = _slugify(location_key)

    if dry_run:
        print(f"\n  Location: {location_key}")
        print(f"  Slug: {slug}")
        print(f"  Zone: {_get_zone(location_data)}")
        print(f"  Model: {model} (${get_cost(model):.3f})")
        print(f"  Prompt ({len(prompt)} chars):")
        print(f"    {prompt[:200]}...")
        return None

    if not _HAS_GENAI:
        logger.error("google-genai SDK not installed. Cannot generate.")
        return None

    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        logger.error("GEMINI_API_KEY not set.")
        return None

    client = genai.Client(api_key=api_key)

    # Generate
    try:
        image_data = _request_location_image(client, model, prompt)
    except Exception as e:
        logger.error("  API error for %s: %s", slug, e)
        return None

    if not image_data:
        logger.warning("  No image returned for %s", slug)
        return None

    # ── Sterility gate: reject location refs that contain human figures ──
    if not _check_sterility(image_data, client, slug):
        logger.warning("  REJECTED %s: sterility check failed (human figures detected)", slug)
        # Retry once with stronger negative prompt
        logger.info("  Retrying %s with reinforced prompt...", slug)
        retry_prompt = prompt + "\n\nCRITICAL: This image MUST be completely empty. Absolutely NO human figures, bodies, silhouettes, or any suggestion of people. Empty environment only."
        try:
            image_data = _request_location_image(client, model, retry_prompt)
        except Exception as e:
            logger.error("  Retry failed for %s: %s", slug, e)
            return None
        if not image_data:
            # Distinguish "nothing came back" from an actual sterility failure.
            logger.error("  Retry returned no image for %s — skipping", slug)
            return None
        if not _check_sterility(image_data, client, slug):
            logger.error("  Retry also failed sterility for %s — skipping", slug)
            return None
        logger.info("  Retry passed sterility check for %s", slug)

    # Save
    # NOTE(review): bytes are written as returned under a .png name; confirm
    # the image model always returns PNG data.
    out_dir = ProjectPaths.for_project(project).asset_subject_dir("loc", slug)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{slug}_ref.png"
    out_path.write_bytes(image_data)
    logger.info("  Saved: %s", out_path)
    return out_path


def main():
    """Command-line entry point: generate reference images for selected locations.

    Selects locations with ``--episode N`` or ``--all``, drops those that
    already have a reference image unless ``--no-skip-existing`` is given,
    prints a cost estimate, then generates each image in turn with a 2 second
    pause between API calls. With ``--dry-run`` only the prompts are printed.
    """
    parser = argparse.ArgumentParser(description="Generate location reference images")
    parser.add_argument("--episode", type=int, help="Generate refs for locations in this episode")
    parser.add_argument("--all", action="store_true", help="Generate refs for all 72 locations")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be generated")
    parser.add_argument("--model", default=get_model("production", "image"),
                        help="Model to use (default: %(default)s)")
    # BooleanOptionalAction also registers --no-skip-existing; a plain
    # store_true flag with default=True could never be switched off.
    parser.add_argument("--skip-existing", action=argparse.BooleanOptionalAction, default=True,
                        help="Skip locations that already have refs; "
                             "pass --no-skip-existing to regenerate them")
    parser.add_argument("--project", default=None)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s",
                        datefmt="%H:%M:%S")

    # Compare against None so that "--episode 0" counts as a selection.
    if args.episode is None and not args.all:
        parser.error("Specify --episode N or --all")

    bd = load_breakdown(args.project)
    config = load_project_config(args.project)
    locations = bd.get("locations", {})

    # Filter locations (--all takes precedence when both are given)
    if args.all:
        target_keys = list(locations)
    else:
        target_keys = get_episode_locations(args.episode, args.project)

    if not target_keys:
        if args.all:
            print("No locations found in breakdown")
        else:
            print(f"No locations found for episode {args.episode}")
        return

    # Filter out existing
    if args.skip_existing:
        target_keys = [k for k in target_keys if not location_has_ref(k, project=args.project)]

    cost_per = get_cost(args.model)
    total_cost = cost_per * len(target_keys)

    print("=== Location Reference Generator ===")
    print(f"Locations to generate: {len(target_keys)}")
    print(f"Model: {args.model} (${cost_per:.3f}/image)")
    print(f"Estimated cost: ${total_cost:.3f}")
    if args.dry_run:
        print("MODE: DRY RUN")

    generated = 0
    failed = 0

    for i, loc_key in enumerate(target_keys, 1):
        loc_data = locations[loc_key]
        print(f"\n[{i}/{len(target_keys)}] {loc_key}")

        result = generate_location_ref(
            loc_key, loc_data, config,
            model=args.model, dry_run=args.dry_run,
            project=args.project,
        )

        if result:
            generated += 1
        elif not args.dry_run:
            # A dry run always returns None, so it is not counted as a failure.
            failed += 1

        # Rate limit: wait between calls (no wait after the last one)
        if not args.dry_run and i < len(target_keys):
            time.sleep(2)

    if not args.dry_run:
        print("\n=== Done ===")
        print(f"Generated: {generated}")
        print(f"Failed: {failed}")
        print(f"Total cost: ~${generated * cost_per:.3f}")

if __name__ == "__main__":
    main()
