#!/usr/bin/env python3
"""
probe_nbp_pose_match.py — Probe NBP keyframe gen with style+pose dual reference.

Uses StepRunner.execute_keyframe() with a scene_ref (style/look target) and an
optional pose_ref (skeleton/stance source) so NBP regenerates the look image at
the reference pose; without a pose_ref the pose comes from the prompt alone.
Sanctioned route — the same path Plan-Pass keyframe generation uses.

Usage:
    python3 recoil/pipeline/tools/probe_nbp_pose_match.py \\
      --look /path/to/look_target.jpg \\
      --project _probes \\
      --prompt "Match the pose of the reference while keeping the 2D cell..." \\
      [--pose /path/to/pose_reference.jpg] [--model MODEL_ID] [--aspect 16:9] \\
      [--shot-id SHOT_ID] [--dry-run]
"""

import argparse
import os
import sys
import time
from pathlib import Path
from typing import Optional

# Bootstrap sys.path so the `recoil.*` imports below resolve when this file is
# run directly as a script. _HERE is this file's directory, _RECOIL_ROOT is
# two directory levels above it, and _PROJECTS_ROOT_PARENT is one level above
# _RECOIL_ROOT.
_HERE = os.path.dirname(os.path.abspath(__file__))
_RECOIL_ROOT = os.path.dirname(os.path.dirname(_HERE))
_PROJECTS_ROOT_PARENT = os.path.dirname(_RECOIL_ROOT)
# Each directory is prepended only if it is not already on sys.path. When both
# are inserted, _PROJECTS_ROOT_PARENT (inserted second) ends up ahead of
# _RECOIL_ROOT. Presumably the parent directory is what makes the top-level
# `recoil` package importable — TODO confirm against the repo layout.
if _RECOIL_ROOT not in sys.path:
    sys.path.insert(0, _RECOIL_ROOT)
if _PROJECTS_ROOT_PARENT not in sys.path:
    sys.path.insert(0, _PROJECTS_ROOT_PARENT)

# Must come after the sys.path edits above, hence the E402 suppression.
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402

# Runs at import time, before main(). Presumably it performs further import
# setup so the recoil.execution modules imported lazily in main() resolve —
# NOTE(review): verify against recoil.core.paths.
ensure_pipeline_importable()


def _resolve_input_image(raw: str, flag: str) -> Optional[Path]:
    """Resolve a command-line image path and check it names a regular file.

    Args:
        raw: The path as given on the command line; ``~`` is expanded.
        flag: The option name used in error messages, e.g. ``"--look"``.

    Returns:
        The absolute, resolved path. Returns ``None`` after printing an error
        to stderr when the path does not exist or is not a regular file.
        A bare ``exists()`` check would let a directory through to the runner.
    """
    path = Path(raw).expanduser().resolve()
    if not path.exists():
        print(f"ERROR: {flag} not found: {path}", file=sys.stderr)
        return None
    if not path.is_file():
        print(f"ERROR: {flag} is not a file: {path}", file=sys.stderr)
        return None
    return path


def main() -> int:
    """Parse CLI args, validate inputs, and run one NBP pose-match keyframe.

    Prints the resolved run plan. With ``--dry-run`` it stops there.
    Otherwise it submits a single ``StepRunner.execute_keyframe`` call with
    no gates and no gate retries, then reports status, output path, cost and
    elapsed time.

    Returns:
        Process exit code. 0 on success or after a dry run. 1 when an input is
        invalid (missing or non-file image, blank prompt) or the keyframe step
        reports failure.
    """
    parser = argparse.ArgumentParser(
        description="Probe NBP pose-match (style ref + pose ref → new keyframe)."
    )
    parser.add_argument("--look", required=True, help="Style/look reference image")
    parser.add_argument("--pose", default=None, help="Pose reference image (optional)")
    parser.add_argument("--prompt", required=True, help="Generation prompt")
    parser.add_argument(
        "--model",
        default="gemini-3.1-flash-image-preview",
        help="Image model id (default: Flash 3.1 — cheaper than NBP Pro at $0.04 vs $0.13)",
    )
    parser.add_argument("--aspect", default="16:9", help="Aspect ratio (default: 16:9)")
    parser.add_argument(
        "--project", required=True, help="Project name for the execution store and paths"
    )
    parser.add_argument(
        "--shot-id", default=None, help="Shot id (default: NBP_POSE_<unix seconds>)"
    )
    parser.add_argument(
        "--dry-run", action="store_true", help="Print the plan and exit without submitting"
    )
    args = parser.parse_args()

    look_path = _resolve_input_image(args.look, "--look")
    if look_path is None:
        return 1
    pose_path: Optional[Path] = None
    if args.pose:
        pose_path = _resolve_input_image(args.pose, "--pose")
        if pose_path is None:
            return 1

    # A blank prompt would otherwise be submitted as-is; reject it up front.
    if not args.prompt.strip():
        print("ERROR: --prompt is empty", file=sys.stderr)
        return 1

    # Wall-clock seconds keep auto-generated shot ids unique across runs.
    shot_id = args.shot_id or f"NBP_POSE_{int(time.time())}"

    print()
    print("Mode:        NBP pose-match (probe)")
    print(f"Project:     {args.project}")
    print(f"Shot ID:     {shot_id}")
    print(f"Model:       {args.model}")
    print(f"Aspect:      {args.aspect}")
    print(f"Look ref:    {look_path.name}")
    print(f"Pose ref:    {pose_path.name if pose_path else '(none — prompt-only pose)'}")
    print(f"Prompt ({len(args.prompt.split())} words):")
    print("  " + args.prompt.replace("\n", "\n  "))

    if args.dry_run:
        print()
        print("=== DRY RUN — not submitting ===")
        return 0

    # Imported here so a dry run never loads the execution stack.
    from recoil.execution.execution_store import ExecutionStore
    from recoil.execution.step_runner import StepRunner
    from recoil.execution.step_types import ProjectPaths

    store = ExecutionStore(args.project)
    paths = ProjectPaths.for_episode(args.project, 1)
    paths.frames_dir.mkdir(parents=True, exist_ok=True)

    runner = StepRunner(store=store, paths=paths, validate_frames=False)
    # Private attribute; presumably tags where the dispatch originated — confirm
    # against StepRunner.
    runner._dispatch_path = "probe_nbp_pose_match.py"

    # Monotonic clock: elapsed time is immune to wall-clock adjustments.
    t0 = time.monotonic()
    result = runner.execute_keyframe(
        shot_id=shot_id,
        prompt=args.prompt,
        model=args.model,
        scene_ref_path=look_path,
        pose_ref_path=pose_path,
        aspect_ratio=args.aspect,
        gates=[],
        max_gate_retries=0,
    )
    elapsed = time.monotonic() - t0

    status = "OK" if result.success else "FAIL"
    print()
    print(f"[{status}] {shot_id} -> {result.output_path} (${result.cost_usd or 0:.2f})")
    if result.error:
        print(f"Error: {result.error}")
    print(f"Done in {elapsed:.0f}s")
    return 0 if result.success else 1


if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the shell exactly
    # as sys.exit() would.
    raise SystemExit(main())
