#!/usr/bin/env python3
"""
prep_expressions.py — Generate universal grayscale expression matrix.

Generates 3 expression grids using Flash 3.1; each grid covers 3 emotions ×
3 intensities (9 panels), for 27 expression refs in total. This is a
universal library: run it once per show. The results are stored in
assets/expressions/ and shared across all characters.

ADR-C05: Universal generic actor (bald, androgynous) prevents identity
over-baking when expression refs are combined with character identity refs.

Usage:
    python -m tools.prep_expressions                  # Generate all 3 grids
    python -m tools.prep_expressions --dry-run         # Preview without API calls
    python -m tools.prep_expressions --set 0           # Generate only first set
"""

import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import Optional

sys.path.insert(0, str(Path(__file__).parent.parent))

from recoil.core.paths import PIPELINE_ROOT
from recoil.core.model_profiles import get_model
from recoil.core.prompt_config import get_constant

logger = logging.getLogger("starsend.prep_expressions")

# Shared destination for every grid and split panel this tool writes.
OUTPUT_DIR = PIPELINE_ROOT / "assets" / "expressions"

# Image model resolved from the project's "exploration" profile, plus the
# per-request cost (USD) added to the running total for each grid attempt.
FLASH_MODEL = get_model("exploration", "image")
FLASH_COST = 0.039

# One tuple per grid; its three emotions become the grid's columns.
# 3 grids × 3 emotions = 9 emotions; × 3 intensities = 27 expression refs.
EMOTION_SETS = [
    ("joy", "anger", "sorrow"),
    ("fear", "determination", "exhaustion"),
    ("surprise", "disgust", "neutral"),
]

# Grid rows, listed top to bottom.
INTENSITIES = ["subtle", "active", "extreme"]


def generate_universal_expressions(
    emotion_set_index: Optional[int] = None,
    dry_run: bool = False,
) -> dict:
    """Generate universal grayscale expression matrix.

    Produces one grid per entry of EMOTION_SETS (or a single grid if
    emotion_set_index is specified). Each grid holds that set's emotions as
    columns and the INTENSITIES levels as rows. Each raw grid is saved as
    assets/expressions/expression_grid_{index}.png and split into
    assets/expressions/{emotion}_{intensity}.png.

    Args:
        emotion_set_index: Generate only this emotion set
            (0 .. len(EMOTION_SETS) - 1). None = all.
        dry_run: Preview the planned outputs without API calls and without
            creating the output directory.

    Returns:
        dict with keys: grids, expressions, cost, output_dir.
        ``cost`` adds one FLASH_COST per attempted grid, including attempts
        that returned no image.

    Raises:
        IndexError: if emotion_set_index is outside the valid range
            (negative values are rejected instead of wrapping around).
    """
    total_sets = len(EMOTION_SETS)
    if emotion_set_index is None:
        set_indices = list(range(total_sets))
    elif 0 <= emotion_set_index < total_sets:
        set_indices = [emotion_set_index]
    else:
        # Plain indexing would accept negatives and write files such as
        # expression_grid_-1.png; reject them explicitly.
        raise IndexError(
            f"emotion_set_index must be in 0..{total_sets - 1}, "
            f"got {emotion_set_index}"
        )

    result = {
        "grids": [],
        "expressions": [],
        "cost": 0.0,
        "output_dir": str(OUTPUT_DIR),
    }

    if dry_run:
        logger.info("[DRY RUN] Would generate %d expression grids", len(set_indices))
        for idx in set_indices:
            emotion_set = EMOTION_SETS[idx]
            result["grids"].append({
                "set_index": idx,
                "emotions": list(emotion_set),
                "grid_path": None,
                "dry_run": True,
            })
            # Same emotion-major ordering the real split produces.
            result["expressions"].extend(
                {
                    "emotion": emotion,
                    "intensity": intensity,
                    "path": str(OUTPUT_DIR / f"{emotion}_{intensity}.png"),
                }
                for emotion in emotion_set
                for intensity in INTENSITIES
            )
        logger.info(
            "[DRY RUN] %d grids × 9 panels = %d expression refs planned",
            len(set_indices),
            len(result["expressions"]),
        )
        return result

    # Only touch the filesystem when we are actually going to write files.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Deferred import: the prompt engine is only needed for real generation.
    from recoil.pipeline._lib.prompt_engine import build_universal_expression_matrix

    for idx in set_indices:
        emotion_set = EMOTION_SETS[idx]
        logger.info(
            "Generating expression grid %d/%d: %s",
            idx + 1, total_sets, ", ".join(emotion_set),
        )

        prompt = build_universal_expression_matrix(*emotion_set)
        grid_bytes = _generate_flash_grid(prompt)

        grid_info = {
            "set_index": idx,
            "emotions": list(emotion_set),
            "grid_path": None,
        }
        # Charged per attempt, even when no image comes back (conservative).
        result["cost"] += FLASH_COST

        if grid_bytes is None:
            logger.error("Grid generation failed for set %d: %s", idx, emotion_set)
            result["grids"].append(grid_info)
            continue

        grid_path = OUTPUT_DIR / f"expression_grid_{idx}.png"
        grid_path.write_bytes(grid_bytes)
        grid_info["grid_path"] = str(grid_path)
        result["grids"].append(grid_info)
        logger.info("Grid saved: %s", grid_path)

        # Split grid into individual expression panels
        expressions = _split_expression_grid(grid_path, emotion_set)
        result["expressions"].extend(expressions)

    logger.info(
        "Expression matrix complete: %d grids, %d refs, cost $%.3f",
        len(result["grids"]),
        len(result["expressions"]),
        result["cost"],
    )
    return result


def _generate_flash_grid(prompt: str) -> Optional[bytes]:
    """Generate an expression grid via Flash 3.1 (text-only, no hero image).

    Args:
        prompt: Text prompt describing the expression grid to render.

    Returns:
        Raw bytes of the first inline-data part found across the response
        candidates, or None. None is returned when the SDK is missing, no API
        key is configured, client creation or the request raises, or the
        response holds no inline data. Every None path is logged.
    """
    try:
        from google import genai
        from google.genai import types
    except ImportError:
        logger.error("google-genai SDK not installed")
        return None

    # GEMINI_API_KEY takes precedence; GOOGLE_API_KEY is accepted as fallback.
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        logger.error("GEMINI_API_KEY (or GOOGLE_API_KEY) not set")
        return None

    try:
        # Client construction is inside the try so configuration errors fail
        # soft (logged + None) like request errors do.
        client = genai.Client(api_key=api_key)
        response = client.models.generate_content(
            model=FLASH_MODEL,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=0.2,
                response_modalities=["IMAGE", "TEXT"],
                image_config=types.ImageConfig(
                    aspect_ratio="1:1",
                ),
            ),
        )
    except Exception as e:  # SDK/network boundary: report and fail soft
        logger.error("Flash API error: %s", e)
        return None

    candidates = (response.candidates if response else None) or []
    text_parts = []
    for candidate in candidates:
        content = candidate.content
        if not (content and content.parts):
            continue
        for part in content.parts:
            inline = getattr(part, "inline_data", None)
            if inline:
                return inline.data
            text = getattr(part, "text", None)
            if text:
                text_parts.append(text)

    # No inline image part was found; log any text the model returned so the
    # caller's generic "generation failed" error has a visible cause.
    logger.warning(
        "Flash response contained no image data (%d candidate(s)); text: %r",
        len(candidates),
        " ".join(text_parts)[:500],
    )
    return None


def _split_expression_grid(
    grid_path: Path,
    emotions: tuple,
) -> list:
    """Split an expression grid into individual grayscale expression images.

    Grid layout (ADR-C05):
        Columns: one per entry of ``emotions`` (emotion_1, emotion_2, ...)
        Rows: one per entry of INTENSITIES — subtle (top), active (middle),
        extreme (bottom)

    Label/border regions are cropped first using
    ``tools.prep_character_angles._detect_content_bounds``. The remaining
    content area is then divided into equal cells by integer division, so
    any remainder pixels on the right/bottom edges are dropped.

    Output naming: {emotion}_{intensity}.png in OUTPUT_DIR (existing files
    are overwritten). All panels are converted to grayscale.

    Args:
        grid_path: Grid image to split.
        emotions: Column emotions, left to right.

    Returns:
        List of {"emotion", "intensity", "path"} dicts in emotion-major
        order. The list is empty if PIL is unavailable, the grid file is
        missing, or the detected content area is too small to split.
    """
    try:
        from PIL import Image
    except ImportError:
        logger.warning("PIL not available — cannot split grid")
        return []

    if not grid_path.exists():
        logger.warning("Grid file not found — cannot split: %s", grid_path)
        return []

    n_cols = len(emotions)
    n_rows = len(INTENSITIES)
    if n_cols == 0 or n_rows == 0:
        return []

    from tools.prep_character_angles import _detect_content_bounds

    # Context manager releases the file handle once all panels are written.
    with Image.open(grid_path) as img:
        # Detect and crop label/border regions before splitting
        top, bottom, left, right = _detect_content_bounds(img)
        content = img.crop((left, top, right, bottom))
        cw, ch = content.size
        panel_w = cw // n_cols
        panel_h = ch // n_rows
        if panel_w == 0 or panel_h == 0:
            # A zero-sized panel cannot be saved; bail out with a diagnostic.
            logger.warning(
                "Content area %dx%d too small to split into %dx%d panels: %s",
                cw, ch, n_cols, n_rows, grid_path,
            )
            return []

        expressions = []
        for col_idx, emotion in enumerate(emotions):
            for row_idx, intensity in enumerate(INTENSITIES):
                box = (
                    col_idx * panel_w,
                    row_idx * panel_h,
                    (col_idx + 1) * panel_w,
                    (row_idx + 1) * panel_h,
                )
                # Convert to grayscale (ADR-C05: prevents identity bleed)
                panel = content.crop(box).convert("L")

                panel_path = OUTPUT_DIR / f"{emotion}_{intensity}.png"
                panel.save(panel_path)

                expressions.append({
                    "emotion": emotion,
                    "intensity": intensity,
                    "path": str(panel_path),
                })

    logger.info("Split grid into %d expressions", len(expressions))
    return expressions


def main():
    """CLI entry point: parse arguments, run generation, print the result as JSON.

    Returns:
        Process exit status: 0 on success or dry run, 1 if any requested
        grid failed to generate (its ``grid_path`` is None).
    """
    # Derived from EMOTION_SETS so the CLI cannot drift from the data,
    # e.g. "0=joy/anger/sorrow, 1=fear/determination/exhaustion, ...".
    set_help = ", ".join(
        f"{i}={'/'.join(emotion_set)}" for i, emotion_set in enumerate(EMOTION_SETS)
    )

    parser = argparse.ArgumentParser(
        description="Generate universal grayscale expression matrix (ADR-C05)"
    )
    parser.add_argument(
        "--set", type=int, default=None, choices=list(range(len(EMOTION_SETS))),
        help=f"Generate only this emotion set ({set_help})",
    )
    parser.add_argument("--dry-run", action="store_true", help="Preview without API calls")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    result = generate_universal_expressions(
        emotion_set_index=args.set,
        dry_run=args.dry_run,
    )

    print(json.dumps(result, indent=2, default=str))

    # Surface failures to shells/CI. Dry-run grids are only planned and
    # never count as failures.
    failed = [
        grid for grid in result["grids"]
        if not grid.get("dry_run") and grid["grid_path"] is None
    ]
    if failed:
        logger.error("%d of %d grids failed to generate", len(failed), len(result["grids"]))
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
