#!/usr/bin/env python3
"""
upscale_gemini.py — Gemini NanobananaPro Image Upscale Utility

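Upscales images via Gemini image generation. The default model is resolved with
get_model("upscale", "image"); override it with --model or the model= argument.
Can optionally crop border artifacts (e.g., from triptych panel splitting) before
upscaling.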

Usage (standalone CLI):
    python3 upscale_gemini.py image.jpg                      # → image_hq.png
    python3 upscale_gemini.py image.jpg -o output.png        # → specific path
    python3 upscale_gemini.py dir/ --glob "*.jpg"            # → batch → dir/upscaled/
    python3 upscale_gemini.py dir/ --glob "*.jpg" --crop 4   # → crop 4px then upscale

Usage (importable):
    from upscale_gemini import upscale_image, crop_and_upscale, crop_image
    result = upscale_image("panel.jpg", output_path="panel_hq.png")
    result = crop_and_upscale("panel.jpg", crop_px=4)

    # With cost tracking:
    from cost_tracker import CostTracker
    tracker = CostTracker("leviathan/")
    result = upscale_image("panel.jpg", tracker=tracker)

Env vars:
    GOOGLE_API_KEY — Gemini API key (required)

Dependencies:
    pip install google-genai Pillow
"""

import argparse
import io
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
from typing import Optional

from cost_tracker import CostTracker
from recoil.core.model_profiles import get_model

# Instruction sent alongside the input image on every upscale request: raise
# resolution only, without altering content, framing, or aspect ratio.
UPSCALE_PROMPT = (
    "Upscale this image to the highest resolution possible. "
    "KEEP EVERYTHING EXACTLY THE SAME. Do not change any details, colors, lighting, "
    "composition, facial features, or expression. Do not add or remove anything. "
    "Only increase resolution and sharpen details. Maintain exact same framing and aspect ratio."
)

# Default model name, looked up via the project's get_model("upscale", "image")
# profile. Used as the default for upscale_image, crop_and_upscale and --model.
DEFAULT_MODEL = get_model("upscale", "image")
DEFAULT_DELAY = 5.0  # seconds between API calls (15 RPM limit -> 4 s minimum; 5 s leaves headroom)


def crop_image(path: str, px: int = 4, output_path: Optional[str] = None) -> str:
    """Remove ``px`` pixels from each edge of an image (border artifact removal).

    Args:
        path: Input image path.
        px: Pixels to crop from each edge. Must be non-negative and leave at
            least one pixel in each dimension.
        output_path: Where to save. Defaults to ``{stem}_cropped{suffix}`` next
            to the input. Pillow infers the output format from the extension.

    Returns:
        Path to the cropped image.

    Raises:
        ValueError: If ``px`` is negative, or so large that nothing would remain.
    """
    # Checked before touching Pillow: with a negative px, Image.crop would pad
    # the image with black borders instead of trimming it.
    if px < 0:
        raise ValueError(f"px must be >= 0, got {px}")

    from PIL import Image

    # The context manager closes the input file handle deterministically.
    # Image.crop is eager in current Pillow, so the result outlives the block.
    with Image.open(path) as img:
        w, h = img.size
        if 2 * px >= w or 2 * px >= h:
            raise ValueError(f"Cannot crop {px}px from each edge of a {w}x{h} image")
        cropped = img.crop((px, px, w - px, h - px))

    if output_path is None:
        p = Path(path)
        output_path = str(p.with_stem(p.stem + "_cropped"))

    # quality=95 matches the re-encode quality used by upscale_image, so a JPEG
    # intermediate is not degraded at Pillow's default of 75. Formats that do
    # not use the option (e.g. PNG) ignore it.
    cropped.save(output_path, quality=95)
    return output_path


def upscale_image(
    path: str,
    output_path: Optional[str] = None,
    model: str = DEFAULT_MODEL,
    client=None,
    tracker: Optional[CostTracker] = None,
) -> str:
    """Upscale a single image via Gemini NBP.

    The input is re-encoded in memory and sent together with UPSCALE_PROMPT:
    as PNG for ``.png`` inputs, otherwise as JPEG at quality 95 (modes JPEG
    cannot store, e.g. RGBA or palette, are converted to RGB first). The first
    image part of the first response candidate is written verbatim to
    ``output_path``.

    Args:
        path: Input image path.
        output_path: Where to save. Defaults to ``{stem}_hq.png`` next to the
            input. Parent directories are created as needed. The returned bytes
            are written as-is.
        model: Gemini model name.
        client: Existing ``genai.Client`` (created from GOOGLE_API_KEY if None).
        tracker: CostTracker instance for cost logging (optional). One entry is
            logged per call, on success and on every failure path.

    Returns:
        Path to the upscaled image.

    Raises:
        RuntimeError: If GOOGLE_API_KEY is unset (when no client is given), or
            the response contains no image (including blocked/empty responses).
        Exception: Any error from ``generate_content`` is logged to the tracker
            and re-raised unchanged.
    """
    from PIL import Image

    if client is None:
        from google import genai
        api_key = os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise RuntimeError("GOOGLE_API_KEY not set")
        client = genai.Client(api_key=api_key)

    from google.genai import types

    fmt = "PNG" if Path(path).suffix.lower() == ".png" else "JPEG"
    mime = "image/png" if fmt == "PNG" else "image/jpeg"
    # Modes Pillow's JPEG encoder can write; any other mode (RGBA, LA, P, ...)
    # would make save() raise OSError, so those are flattened to RGB.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    buf = io.BytesIO()
    with Image.open(path) as img:  # close the input file handle promptly
        to_encode = img
        if fmt == "JPEG" and img.mode not in jpeg_modes:
            to_encode = img.convert("RGB")
        to_encode.save(buf, format=fmt, quality=95)
    image_bytes = buf.getvalue()

    input_filename = Path(path).name
    t0 = time.monotonic()  # monotonic: immune to wall-clock adjustments

    try:
        response = client.models.generate_content(
            model=model,
            contents=[
                UPSCALE_PROMPT,
                types.Part.from_bytes(data=image_bytes, mime_type=mime),
            ],
            config=types.GenerateContentConfig(
                response_modalities=["IMAGE", "TEXT"],
            ),
        )
    except Exception as e:
        elapsed_ms = int((time.monotonic() - t0) * 1000)
        if tracker:
            tracker.log(
                category="upscale",
                provider="gemini",
                model=model,
                images_in=0,
                images_out=0,
                tokens_in=0,
                tokens_out=0,
                duration_ms=elapsed_ms,
                detail=f"Gemini upscale: {input_filename} — {str(e)[:100]}",
                success=False,
            )
        raise

    elapsed_ms = int((time.monotonic() - t0) * 1000)

    # Extract token usage from response metadata (missing/None counts -> 0)
    usage = getattr(response, "usage_metadata", None)
    tokens_in = (getattr(usage, "prompt_token_count", 0) or 0) if usage else 0
    tokens_out = (getattr(usage, "candidates_token_count", 0) or 0) if usage else 0

    # A blocked or empty response can carry a candidate whose content, or whose
    # parts list, is None. Treat that like "no image" so the failure is logged
    # and reported as RuntimeError instead of crashing on iteration.
    parts = []
    if response.candidates:
        content = response.candidates[0].content
        if content is not None and content.parts:
            parts = content.parts

    # Extract the first image part from the response
    for part in parts:
        inline = part.inline_data
        if inline and inline.data and (inline.mime_type or "").startswith("image/"):
            if output_path is None:
                p = Path(path)
                output_path = str(p.with_stem(p.stem + "_hq").with_suffix(".png"))
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, "wb") as f:
                f.write(inline.data)
            if tracker:
                tracker.log(
                    category="upscale",
                    provider="gemini",
                    model=model,
                    images_out=1,
                    tokens_in=tokens_in,
                    tokens_out=tokens_out,
                    duration_ms=elapsed_ms,
                    detail=f"Gemini upscale: {input_filename}",
                    success=True,
                )
            return output_path

    # No image: surface any text the model returned (often a refusal reason)
    text_parts = [part.text for part in parts if getattr(part, "text", None)]

    error = " ".join(text_parts)[:200] if text_parts else "No image in response"
    if tracker:
        tracker.log(
            category="upscale",
            provider="gemini",
            model=model,
            images_out=0,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            duration_ms=elapsed_ms,
            detail=f"Gemini upscale: {input_filename} — {error}",
            success=False,
        )
    raise RuntimeError(f"Upscale failed: {error}")


def crop_and_upscale(
    path: str,
    crop_px: int = 4,
    output_path: Optional[str] = None,
    model: str = DEFAULT_MODEL,
    client=None,
    tracker: Optional[CostTracker] = None,
) -> str:
    """Crop border artifacts then upscale via Gemini NBP.

    The cropped intermediate is written into a private temporary directory,
    never next to the input. An existing ``{stem}_cropped`` file beside the
    input is therefore never overwritten or deleted. The temporary directory
    is removed afterwards, even if cropping or upscaling fails.

    Args:
        path: Input image path.
        crop_px: Pixels to remove from each edge.
        output_path: Final output path. Defaults to ``{stem}_hq.png`` next to
            the original input.
        model: Gemini model name.
        client: Existing genai.Client.
        tracker: CostTracker instance for cost logging (optional).

    Returns:
        Path to the upscaled image.

    Raises:
        Whatever crop_image or upscale_image raise (e.g. ValueError for an
        invalid crop, RuntimeError when no image is returned).
    """
    p = Path(path)
    if output_path is None:
        # Derive the default from the ORIGINAL input. Left to upscale_image,
        # it would be derived from the temp file and land in the temp dir.
        output_path = str(p.with_stem(p.stem + "_hq").with_suffix(".png"))

    tmp_dir = tempfile.mkdtemp(prefix="upscale_crop_")
    try:
        # Keep the "{stem}_cropped{suffix}" file name: upscale_image uses it in
        # tracker log details and picks PNG vs JPEG encoding from the suffix.
        cropped_path = str(Path(tmp_dir) / f"{p.stem}_cropped{p.suffix}")
        crop_image(path, px=crop_px, output_path=cropped_path)
        return upscale_image(
            cropped_path,
            output_path=output_path,
            model=model,
            client=client,
            tracker=tracker,
        )
    finally:
        # Best-effort cleanup; a leftover temp dir must not mask the real result.
        shutil.rmtree(tmp_dir, ignore_errors=True)


def _find_tracker(input_path: Path, max_levels: int = 6) -> Optional[CostTracker]:
    """Return a CostTracker for the nearest ancestor holding ``visual/``, else None.

    The search starts at ``input_path``, or at its parent when it is a file.
    It checks up to ``max_levels`` directories, moving one level up each time.
    """
    # resolve() first: for a relative input like "img.jpg" the parent is ".",
    # and Path(".").parent is still ".", so the walk would never leave the cwd.
    probe = (input_path if input_path.is_dir() else input_path.parent).resolve()
    for _ in range(max_levels):
        if (probe / "visual").is_dir():
            return CostTracker(str(probe))
        probe = probe.parent
    return None


def main():
    """Command-line entry point: upscale one image or a globbed batch of images.

    Exits with status 0 when every file was upscaled (or after --dry-run), and
    with 1 on argument errors or when any file failed.
    """
    parser = argparse.ArgumentParser(description="Upscale images via Gemini NBP")
    parser.add_argument("input", help="Image file or directory")
    parser.add_argument("-o", "--output", help="Output path (single file mode)")
    parser.add_argument("--glob", help="Glob pattern for batch mode (e.g., '*.jpg')")
    parser.add_argument("--crop", type=int, default=0, help="Crop N pixels from each edge before upscaling")
    parser.add_argument("--model", default=DEFAULT_MODEL, help=f"Gemini model (default: {DEFAULT_MODEL})")
    parser.add_argument("--delay", type=float, default=DEFAULT_DELAY, help=f"Seconds between API calls (default: {DEFAULT_DELAY})")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be processed")
    args = parser.parse_args()

    from PIL import Image

    input_path = Path(args.input)

    # Collect files
    if input_path.is_dir():
        if not args.glob:
            print("ERROR: --glob required for directory input", file=sys.stderr)
            sys.exit(1)
        output_dir = input_path / "upscaled"
        # Only regular files, never our own previous outputs. Patterns like "*"
        # or "**/*.png" would otherwise match the upscaled/ directory (crashing
        # Image.open) or re-upscale earlier *_hq.png results.
        files = sorted(
            f
            for f in input_path.glob(args.glob)
            if f.is_file() and output_dir not in f.parents
        )
    elif input_path.is_file():
        files = [input_path]
        output_dir = input_path.parent
    else:
        print(f"ERROR: Not found: {args.input}", file=sys.stderr)
        sys.exit(1)

    if not files:
        print(f"ERROR: No files matching '{args.glob}' in {input_path}", file=sys.stderr)
        sys.exit(1)

    if args.output and len(files) > 1:
        print(
            f"WARNING: --output is ignored for {len(files)} files; writing to {output_dir}/",
            file=sys.stderr,
        )

    print(f"Gemini NBP Upscale — {len(files)} file(s)")
    print(f"Model: {args.model} | Crop: {args.crop}px | Delay: {args.delay}s")
    print()

    if args.dry_run:
        for f in files:
            with Image.open(f) as img:
                print(f"  {f.name}: {img.size[0]}x{img.size[1]}")
        print(f"\nWould output to: {output_dir}/")
        sys.exit(0)

    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
        sys.exit(1)

    # Imported only now so that --dry-run works without google-genai installed.
    from google import genai

    client = genai.Client(api_key=api_key)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Cost tracking is enabled only when a project root (a dir with visual/) is found.
    tracker = _find_tracker(input_path)

    success = 0
    failed = 0
    total = len(files)

    for i, f in enumerate(files, start=1):
        header = f"  [{i}/{total}] {f.name}"
        try:
            with Image.open(f) as img:
                w, h = img.size
        except Exception as e:
            # Unreadable or non-image match: count it and keep going rather
            # than aborting the batch. No API call was made, so no delay.
            print(f"{header}: FAILED to read: {e}")
            failed += 1
            continue
        print(f"{header}: {w}x{h}")

        if total == 1 and args.output:
            out = args.output
        else:
            out = str(output_dir / f"{f.stem}_hq.png")

        try:
            t0 = time.monotonic()
            if args.crop > 0:
                result = crop_and_upscale(str(f), crop_px=args.crop, output_path=out, model=args.model, client=client, tracker=tracker)
            else:
                result = upscale_image(str(f), output_path=out, model=args.model, client=client, tracker=tracker)
            elapsed = time.monotonic() - t0

            with Image.open(result) as out_img:
                print(f"           -> {out_img.size[0]}x{out_img.size[1]} in {elapsed:.1f}s: {result}")
            success += 1
        except Exception as e:
            print(f"           FAILED: {e}")
            failed += 1

        # Throttle between API calls to stay under the rate limit.
        if i < total:
            time.sleep(args.delay)

    print(f"\nDone: {success} upscaled, {failed} failed")
    sys.exit(0 if failed == 0 else 1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a library.
    main()
