#!/usr/bin/env python3
"""
probe_seedream_edit.py — Probe Seedream v4.5 edit endpoint with multi-ref input.

Same fal_client.subscribe pattern as recoil/tools/generate_turnarounds.py.
Uploads each reference image, hits fal-ai/bytedance/seedream/v4.5/edit with
the prompt + image_urls, downloads and saves the result alongside other
keyframe probes.

Usage:
    python3 recoil/pipeline/tools/probe_seedream_edit.py \\
      --image /path/to/look.jpg \\
      --image /path/to/pose.jpg \\
      --prompt "..." \\
      --project _probes \\
      [--aspect landscape_16_9] [--num-images N] [--shot-id NAME] [--dry-run]
"""

import argparse
import os
import sys
import time
from io import BytesIO
from pathlib import Path

# Bootstrap sys.path so the `recoil.*` import below works when this file is
# executed directly as a script: prepend this file's grandparent directory
# and that directory's parent, skipping any that are already present.
_HERE = os.path.dirname(os.path.abspath(__file__))
_RECOIL_ROOT = os.path.dirname(os.path.dirname(_HERE))
_PROJECTS_ROOT_PARENT = os.path.dirname(_RECOIL_ROOT)
for _import_root in (_RECOIL_ROOT, _PROJECTS_ROOT_PARENT):
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)
del _import_root

from recoil.core.paths import projects_root  # noqa: E402


# Values accepted by the --aspect flag (forwarded as the request's image_size).
# NOTE(review): fal's image_size presets are documented as square_hd, square,
# portrait_4_3, portrait_16_9, landscape_4_3 and landscape_16_9 (portrait
# presets are named long-side-first, e.g. portrait_16_9 is a 9:16 frame).
# portrait_9_16, square_1_1 and portrait_3_4 are NOT fal preset names; they
# are kept only so existing command lines keep parsing — confirm they are
# translated to a fal preset before being submitted.
VALID_SIZES = {
    # fal-native presets
    "landscape_16_9",
    "landscape_4_3",
    "portrait_16_9",
    "portrait_4_3",
    "square_hd",
    "square",
    # legacy CLI spellings (not fal preset names)
    "portrait_9_16",
    "square_1_1",
    "portrait_3_4",
}


# fal's image_size presets name portrait frames long-side-first (e.g.
# portrait_16_9 is a 9:16 frame), so some --aspect spellings this CLI accepts
# are not valid fal values. They are translated before submission.
# NOTE(review): mapping assumes fal's documented presets (square_hd, square,
# portrait_4_3, portrait_16_9, landscape_4_3, landscape_16_9) — confirm
# against the Seedream v4.5 edit input schema.
_FAL_IMAGE_SIZE_ALIASES = {
    "portrait_9_16": "portrait_16_9",
    "portrait_3_4": "portrait_4_3",
    "square_1_1": "square_hd",
}

# Maximum number of --image refs allowed per request.
_MAX_IMAGE_REFS = 10

_SEEDREAM_EDIT_ENDPOINT = "fal-ai/bytedance/seedream/v4.5/edit"

# Per-image HTTP timeout (seconds) when downloading generated results.
_DOWNLOAD_TIMEOUT_S = 120


def _fal_image_size(aspect: str) -> str:
    """Return the fal ``image_size`` value to send for an ``--aspect`` choice."""
    return _FAL_IMAGE_SIZE_ALIASES.get(aspect, aspect)


def _positive_int(text: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this probe."""
    parser = argparse.ArgumentParser(
        description="Probe Seedream v4.5 edit endpoint via fal_client."
    )
    parser.add_argument(
        "--image",
        action="append",
        required=True,
        help=f"Reference image path. Repeat to pass multiple (max {_MAX_IMAGE_REFS}).",
    )
    parser.add_argument("--prompt", required=True, help="Edit prompt text")
    parser.add_argument(
        "--aspect",
        default="landscape_16_9",
        choices=sorted(VALID_SIZES),
        help="Output image_size (default: landscape_16_9)",
    )
    parser.add_argument(
        "--num-images",
        type=_positive_int,
        default=1,
        help="Number of variants to request (>= 1)",
    )
    parser.add_argument("--project", required=True, help="Output project")
    parser.add_argument("--shot-id", default=None)
    parser.add_argument("--dry-run", action="store_true")
    return parser


def _resolve_images(raw_paths: list[str]) -> list[Path]:
    """Resolve ``--image`` arguments to absolute paths of existing files.

    Raises:
        ValueError: if more than ``_MAX_IMAGE_REFS`` paths are given, or a
            path does not point to an existing regular file.
    """
    if len(raw_paths) > _MAX_IMAGE_REFS:
        raise ValueError(
            f"Seedream edit accepts max {_MAX_IMAGE_REFS} image refs, got {len(raw_paths)}."
        )
    resolved: list[Path] = []
    for raw in raw_paths:
        ip = Path(raw).expanduser().resolve()
        # is_file() rather than exists(): a directory cannot be uploaded.
        if not ip.is_file():
            raise ValueError(f"--image not found (or not a regular file): {ip}")
        resolved.append(ip)
    return resolved


def _print_summary(args: argparse.Namespace, shot_id: str, image_paths: list[Path]) -> None:
    """Print the run configuration; shown for both real and dry runs."""
    fal_size = _fal_image_size(args.aspect)
    aspect_note = "" if fal_size == args.aspect else f" (sent as {fal_size})"
    print()
    print("Mode:        Seedream v4.5 edit (probe)")
    print(f"Project:     {args.project}")
    print(f"Shot ID:     {shot_id}")
    print(f"Aspect:      {args.aspect}{aspect_note}")
    print(f"Image refs:  {len(image_paths)}")
    for i, ip in enumerate(image_paths, 1):
        print(f"  @Image{i} = {ip.name}  ({ip.stat().st_size / 1024:.1f} KB)")
    print(f"Num images: {args.num_images}")
    print(f"Prompt ({len(args.prompt.split())} words):")
    print("  " + args.prompt.replace("\n", "\n  "))


def main() -> int:
    """Run the probe: validate refs, upload them, submit the edit, save results.

    Results are written to ``<projects_root>/<project>/output/frames/ep_001/``
    as ``<shot_id>_take<N>.png``.

    Returns:
        0 on success or dry run; 1 if validation fails, the endpoint returns no
        images, or any result image fails to download. argparse exits on its
        own for invalid command-line arguments.
    """
    args = _build_parser().parse_args()

    try:
        image_paths = _resolve_images(args.image)
    except ValueError as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        return 1

    shot_id = args.shot_id or f"SEEDREAM_EDIT_{int(time.time())}"
    _print_summary(args, shot_id, image_paths)

    if args.dry_run:
        print()
        print("=== DRY RUN — not submitting ===")
        return 0

    # Imported only after the dry-run exit, so --dry-run does not need
    # fal_client/requests to be installed.
    import fal_client  # noqa: E402
    import requests  # noqa: E402

    print()
    print("Uploading refs...")
    image_urls = []
    for ip in image_paths:
        url = fal_client.upload_file(str(ip))
        image_urls.append(url)
        print(f"  {ip.name} -> {url}")

    print("Submitting...")
    t0 = time.monotonic()
    result = fal_client.subscribe(
        _SEEDREAM_EDIT_ENDPOINT,
        arguments={
            "prompt": args.prompt,
            "image_urls": image_urls,
            "num_images": args.num_images,
            "image_size": _fal_image_size(args.aspect),
            "enable_safety_checker": False,
        },
    )
    elapsed = time.monotonic() - t0

    images = result.get("images", [])
    if not images:
        print(f"ERROR: Seedream returned no images. Result: {result}", file=sys.stderr)
        return 1

    out_dir = projects_root() / args.project / "output" / "frames" / "ep_001"
    out_dir.mkdir(parents=True, exist_ok=True)

    saved: list[Path] = []
    failed = 0
    for i, img in enumerate(images, 1):
        # Best-effort per take: one failed download should not discard the
        # takes that did succeed.
        try:
            resp = requests.get(img["url"], timeout=_DOWNLOAD_TIMEOUT_S)
            resp.raise_for_status()
        except requests.RequestException as exc:
            failed += 1
            print(f"ERROR: download of take {i} failed: {exc}", file=sys.stderr)
            continue
        out_path = out_dir / f"{shot_id}_take{i}.png"
        out_path.write_bytes(resp.content)
        saved.append(out_path)

    print()
    for p in saved:
        print(f"[OK] {shot_id} -> {p}")
    if failed:
        print(
            f"ERROR: {failed} of {len(images)} result image(s) failed to download.",
            file=sys.stderr,
        )
        return 1
    print(f"Done in {elapsed:.0f}s ({len(saved)} image{'s' if len(saved) != 1 else ''})")
    return 0


if __name__ == "__main__":
    # main() returns the process exit status; SystemExit hands it to the OS.
    raise SystemExit(main())
