#!/usr/bin/env python3
"""
restyle.py — Image-to-image restyle / edit dispatch.

THE single canonical entry point for "take an existing image, apply a new prompt,
get a re-rendered output" operations. Wraps Seedream v4.5 edit + NBP edit
(Gemini 3 Pro Image) behind one CLI surface.

If you need to add a new image-edit model in the future:
  - Add the model entry to recoil/config/model_profiles.json (canonical source)
  - If a new api_pattern, add a dispatch branch in main() below
  - Do NOT create a parallel restyle script — extend this one.

Source-of-truth principles enforced:
  - Model IDs / endpoints / costs come from model_profiles.json (NOT hardcoded)
  - Provider routing comes from the profile's `provider` + `api_pattern` fields
  - Sidecar JSON schema matches the existing manual_drop / generate_turnarounds.py
    convention (schema_version: 1, status: candidate, provenance: {...})
  - Aliases live ONLY in this file (clearly marked); future aliases either
    extend the dict here OR get promoted to a top-level `aliases` field in
    model_profiles.json (cleaner long-term, deferred for now)

Usage:
    # Seedream v4.5 edit
    python3 tools/restyle.py --model seedream-v4.5 \\
        --input <input.jpg> \\
        --prompt "Anime restyle in cyberpunk neon style..." \\
        --output <output.jpg>

    # NBP (Nanobanana Pro / Gemini 3 Pro Image) edit
    python3 tools/restyle.py --model nbp \\
        --input <input.jpg> \\
        --prompt "Anime restyle..." \\
        --output <output.jpg>

Outputs:
    <output>          — restyled image
    <output>.json     — sidecar provenance (prompt, model, input hash, cost, etc.)
"""

import argparse
import hashlib
import json
import logging
import os
import sys
import time
from pathlib import Path

# Add recoil/ to sys.path so we can be invoked as a module or script.
# RECOIL_ROOT is the directory two levels above this file.
RECOIL_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(RECOIL_ROOT))

# Module logger; main() installs the handler/format via logging.basicConfig.
logger = logging.getLogger("restyle")

# Canonical model registry. load_model_profile() re-reads it on every call.
CONFIG_PATH = RECOIL_ROOT / "config" / "model_profiles.json"

# Aliases — human-friendly names → canonical model_profiles.json keys.
# This is the ONLY place restyle.py resolves aliases. If a third place
# needs the same mapping, promote this dict into model_profiles.json
# under a top-level `aliases` field instead of duplicating it.
# Lookup is an exact, case-sensitive dict match; ids that are not keys here
# are passed to the profile lookup unchanged.
MODEL_ALIASES = {
    "nbp": "gemini-3-pro-image-preview",
    "nanobanana": "gemini-3-pro-image-preview",
    "nanobanana-pro": "gemini-3-pro-image-preview",
    "flash": "gemini-2.5-flash-image",
    "nanobanana-flash": "gemini-2.5-flash-image",
    # Direct canonical keys (seedream-v4.5, seedream-v5-lite,
    # gemini-3-pro-image-preview, gemini-2.5-flash-image) pass through.
}


# ─────────────────────────────────────────────────────────────────────────
# Model resolution
# ─────────────────────────────────────────────────────────────────────────


def load_model_profile(model_id: str) -> tuple[str, dict]:
    """Resolve an alias and return ``(canonical_id, profile_dict)``.

    Single source of truth: recoil/config/model_profiles.json.

    Args:
        model_id: A canonical model key, or one of the MODEL_ALIASES keys.

    Returns:
        The canonical key and the profile object stored under it.

    Raises:
        ValueError: If the config is not a JSON object, if the resolved key is
            absent, or if the value under that key is not a JSON object (for
            example a top-level comment / metadata entry). Invalid JSON also
            surfaces as ValueError (json.JSONDecodeError subclasses it).
        OSError: If the config file cannot be read.
    """
    canonical = MODEL_ALIASES.get(model_id, model_id)
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    profiles = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
    if not isinstance(profiles, dict):
        raise ValueError(
            f"{CONFIG_PATH} must contain a JSON object mapping model ids to profiles"
        )

    profile = profiles.get(canonical)
    if isinstance(profile, dict):
        return canonical, profile

    # Either the key is missing, or it names a non-profile entry. Returning a
    # non-dict here would only crash later on `profile.get(...)`.
    available = sorted(
        key for key, value in profiles.items()
        if isinstance(value, dict) and value.get("modality") == "image"
    )
    if canonical in profiles:
        problem = "is not a profile object in model_profiles.json"
    else:
        problem = "not in model_profiles.json"
    raise ValueError(
        f"Model '{model_id}' (resolved to '{canonical}') {problem}.\n"
        f"Available image models: {', '.join(available)}\n"
        f"Aliases: {', '.join(MODEL_ALIASES)}"
    )


def _mime_for(p: Path) -> str:
    suf = p.suffix.lower()
    return {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
    }.get(suf, "image/jpeg")


# ─────────────────────────────────────────────────────────────────────────
# Provider dispatchers — one per api_pattern
# ─────────────────────────────────────────────────────────────────────────


def restyle_via_seedream(
    canonical_id: str, profile: dict, input_path: Path, prompt: str,
    aspect_ratio: str, image_size: str, ref_paths: list[Path] | None = None,
) -> bytes:
    """Seedream v4.5 / v5-lite edit via fal_client.subscribe.

    Pattern matches recoil/tools/generate_turnarounds.py — fal_client is the
    documented dispatch path for Seedream image-edit operations.

    `ref_paths` (optional): additional reference images. Seedream's image_urls
    accepts up to max_reference_images per model_profiles (10 for v4.5).
    Convention: input_path is primary (first), ref_paths follow.
    """
    import fal_client

    base = profile.get("base_endpoint")
    if not base:
        raise ValueError(f"Profile for {canonical_id} missing base_endpoint")
    endpoint = base.rstrip("/") + "/edit"

    image_urls = [fal_client.upload_file(str(input_path))]
    logger.info("uploaded primary %s → %s", input_path.name, image_urls[0][:80])
    for rp in (ref_paths or []):
        u = fal_client.upload_file(str(rp))
        image_urls.append(u)
        logger.info("uploaded ref %s → %s", rp.name, u[:80])

    max_refs = profile.get("max_reference_images", 10)
    if len(image_urls) > max_refs:
        raise ValueError(
            f"Total images {len(image_urls)} exceeds {canonical_id} max_reference_images={max_refs}"
        )

    args = {
        "prompt": prompt,
        "image_urls": image_urls,
        "aspect_ratio": aspect_ratio,
        "image_size": image_size,
    }
    logger.info("dispatching %s with %d image(s)", endpoint, len(image_urls))
    result = fal_client.subscribe(endpoint, arguments=args)

    images = result.get("images") or []
    if not images:
        raise RuntimeError(f"Seedream returned no images. Raw: {result}")
    out_url = images[0].get("url") if isinstance(images[0], dict) else images[0]

    import urllib.request
    return urllib.request.urlopen(out_url, timeout=120).read()


def restyle_via_google(
    canonical_id: str, profile: dict, input_path: Path, prompt: str,
    aspect_ratio: str, image_size: str, ref_paths: list[Path] | None = None,
) -> bytes:
    """NBP / Gemini image-edit via google-genai SDK.

    Multimodal generate_content call: contents = [prompt, input_part, *ref_parts].
    Matches the pattern in execution/providers/google.py:direct_submit_image
    but for the edit (input image present) case.

    `ref_paths` (optional): additional identity/style references appended after
    the primary input. Up to max_reference_images per model_profiles (11 for NBP).
    """
    from google import genai
    from google.genai import types as genai_types

    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY not set in env")

    client = genai.Client(api_key=api_key)

    contents = [
        prompt,
        genai_types.Part.from_bytes(
            data=input_path.read_bytes(), mime_type=_mime_for(input_path),
        ),
    ]
    for rp in (ref_paths or []):
        contents.append(genai_types.Part.from_bytes(
            data=rp.read_bytes(), mime_type=_mime_for(rp),
        ))

    max_refs = profile.get("max_reference_images", 11)
    n_images = len(contents) - 1  # subtract the prompt
    if n_images > max_refs:
        raise ValueError(
            f"Total images {n_images} exceeds {canonical_id} max_reference_images={max_refs}"
        )
    config = genai_types.GenerateContentConfig(
        response_modalities=["Image"],
        image_config=genai_types.ImageConfig(aspect_ratio=aspect_ratio),
    )
    logger.info("dispatching %s (genai_inline) ar=%s", canonical_id, aspect_ratio)
    response = client.models.generate_content(
        model=canonical_id, contents=contents, config=config,
    )

    for cand in response.candidates or []:
        for part in cand.content.parts or []:
            inline = getattr(part, "inline_data", None)
            if inline and getattr(inline, "data", None):
                return inline.data
    raise RuntimeError(f"NBP returned no inline image bytes. Response: {response}")


# ─────────────────────────────────────────────────────────────────────────
# Output writing — sidecar schema matches existing convention
# ─────────────────────────────────────────────────────────────────────────


def write_output(
    image_bytes: bytes, output_path: Path,
    user_alias: str, canonical_id: str, profile: dict,
    input_path: Path, prompt: str, aspect_ratio: str, image_size: str,
    ref_paths: list[Path] | None = None,
) -> None:
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_bytes(image_bytes)

    now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    input_hash = hashlib.sha256(input_path.read_bytes()).hexdigest()[:16]

    sidecar = {
        "schema_version": 1,
        "source": "restyle_tool",
        "status": "candidate",
        "created_at": now,
        "updated_at": now,
        "provenance": {
            "tool": "recoil/tools/restyle.py",
            "model": canonical_id,
            "model_alias": user_alias,
            "provider": profile.get("provider"),
            "endpoint": profile.get("base_endpoint"),
            "api_pattern": profile.get("api_pattern"),
            "prompt": prompt,
            "input_image": str(input_path),
            "input_sha256_16": input_hash,
            "ref_images": [str(p) for p in (ref_paths or [])],
            "aspect_ratio": aspect_ratio,
            "image_size": image_size,
            "cost_usd_estimated": profile.get("cost_per_image"),
        },
        "lineage": {
            "derived_from": str(input_path),
            "refs": [str(p) for p in (ref_paths or [])],
        },
        "notes": "",
        "tags": ["restyle"],
    }
    sidecar_path = output_path.with_suffix(output_path.suffix + ".json")
    sidecar_path.write_text(json.dumps(sidecar, indent=2))


# ─────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────


def main() -> int:
    """CLI entry point: parse arguments, dispatch the edit, write the outputs.

    Flow: validate input paths → resolve the model profile → choose the
    dispatcher from the profile's `api_pattern` / `provider` → call it →
    write the image plus sidecar → optionally open the result.

    Returns:
        Process exit code: 0 on success, 1 after printing an ERROR line to
        stderr (missing file, unknown model, unsupported provider, or a
        ValueError / RuntimeError raised by the dispatcher).
    """
    # Local import: only needed for --open-when-done.
    import subprocess

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    parser = argparse.ArgumentParser(
        description=(
            "Image restyle / edit dispatch. Single canonical entry point for "
            "Seedream and NBP image-to-image operations."
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--model", required=True,
        help=(
            # Derived from MODEL_ALIASES so the help text cannot drift from
            # the aliases that are actually accepted.
            f"Model id or alias. Aliases: {', '.join(MODEL_ALIASES)}. "
            "Canonical: seedream-v4.5, seedream-v5-lite, gemini-3-pro-image-preview, "
            "gemini-2.5-flash-image."
        ),
    )
    parser.add_argument("--input", required=True, type=Path,
                        help="Path to input image (png/jpg/jpeg/webp). Primary edit target.")
    parser.add_argument("--ref-image", action="append", type=Path, default=[],
                        dest="ref_images",
                        help="Additional reference image(s) for identity / style pinning. "
                             "Repeatable. Sent alongside --input. Max constrained by model "
                             "(see max_reference_images in model_profiles.json).")
    parser.add_argument("--prompt", required=True,
                        help="Restyle / edit prompt")
    parser.add_argument("--output", required=True, type=Path,
                        help="Output image path. Sidecar JSON saved as <output>.json")
    parser.add_argument(
        "--aspect-ratio", default="9:16",
        help="Output aspect ratio. Constrained per model (see model_profiles.json).",
    )
    parser.add_argument(
        "--size", default="2K",
        help="Output size. Constrained per model (see model_profiles.json supported_sizes).",
    )
    parser.add_argument(
        "--open-when-done", action="store_true",
        help="`open` the output image after writing (macOS).",
    )
    args = parser.parse_args()

    # is_file() rather than exists(): a directory would pass exists() and then
    # fail later inside the dispatcher with a less helpful error.
    if not args.input.is_file():
        print(f"ERROR: input not found: {args.input}", file=sys.stderr)
        return 1
    for rp in args.ref_images:
        if not rp.is_file():
            print(f"ERROR: --ref-image not found: {rp}", file=sys.stderr)
            return 1

    try:
        canonical_id, profile = load_model_profile(args.model)
    except (ValueError, OSError) as e:
        # OSError covers a missing / unreadable model_profiles.json.
        print(f"ERROR: {e}", file=sys.stderr)
        return 1

    provider = profile.get("provider")
    api_pattern = profile.get("api_pattern")

    print(f"Model:    {args.model} → {canonical_id} ({profile.get('display_name')})")
    print(f"Provider: {provider}  ·  api_pattern: {api_pattern}")
    print(f"Input:    {args.input}")
    print(f"Output:   {args.output}")
    print(f"Aspect:   {args.aspect_ratio}  ·  Size: {args.size}")
    print(f"Prompt ({len(args.prompt.split())} words):")
    print("  " + args.prompt[:300] + ("..." if len(args.prompt) > 300 else ""))
    print()

    if args.ref_images:
        print(f"Refs:     {len(args.ref_images)} identity/style image(s)")
        for rp in args.ref_images:
            print(f"          {rp}")

    # Pick the dispatcher; both share the same call signature.
    if api_pattern == "fal_ai_seedream":
        dispatch = restyle_via_seedream
    elif provider == "google" and profile.get("modality") == "image":
        dispatch = restyle_via_google
    else:
        print(
            f"ERROR: no restyle path for provider={provider}, "
            f"api_pattern={api_pattern}. To add support, extend this file's "
            f"main() with a new branch — do NOT create a parallel script.",
            file=sys.stderr,
        )
        return 1

    t0 = time.time()
    try:
        image_bytes = dispatch(
            canonical_id, profile, args.input, args.prompt,
            args.aspect_ratio, args.size, ref_paths=args.ref_images,
        )
    except (ValueError, RuntimeError) as e:
        # These are the errors the dispatchers raise deliberately (bad
        # profile, too many images, missing API key, empty response); report
        # them like every other CLI error instead of dumping a traceback.
        print(f"ERROR: {e}", file=sys.stderr)
        return 1

    write_output(
        image_bytes, args.output,
        args.model, canonical_id, profile,
        args.input, args.prompt, args.aspect_ratio, args.size,
        ref_paths=args.ref_images,
    )

    elapsed = time.time() - t0
    # `or 0`: the key may be present with a null value, which would otherwise
    # crash the :.3f format after all the work has been done.
    cost = profile.get("cost_per_image") or 0
    print(f"OK: wrote {args.output} ({len(image_bytes):,} bytes) in {elapsed:.1f}s")
    print(f"    sidecar: {args.output}.json")
    print(f"    est cost: ${cost:.3f}")

    if args.open_when_done:
        # argv list, no shell: a path containing quotes, spaces or `$` is
        # passed to `open` verbatim instead of being interpreted by a shell.
        try:
            subprocess.run(["open", str(args.output)], check=False)
        except OSError as e:
            # Best-effort convenience; e.g. `open` does not exist off macOS.
            logger.warning("could not open %s: %s", args.output, e)

    return 0


if __name__ == "__main__":
    sys.exit(main())
