#!/usr/bin/env python3
"""generate_keyframes.py — Generate keyframes via StepRunner (model-agnostic).

Each shot is dispatched as an 'image_t2i' job with a StepRunner-backed
DispatchContext; keyframe generation uses StepRunner.execute_keyframe(),
which supports any model client that implements generate_keyframe().
Unlike test_via_steprunner.py (video only) or the Production Console
(NBP-hardcoded), this tool routes through StepRunner with a configurable
model (--model).

Usage:
    python3 tools/generate_keyframes.py --project afterimage-anime --episode 1 --model seedream-v4.5
    python3 tools/generate_keyframes.py --project afterimage --episode 1 --model gemini-3-pro-image-preview
    python3 tools/generate_keyframes.py --project afterimage-anime --episode 1 --shots EP001_SH01,EP001_SH05 --model seedream-v4.5
"""

import argparse
import json
import os
import sys
import time

# Make the repository root importable when this file is run directly as a script
# (e.g. `python3 tools/generate_keyframes.py`), so the `recoil.*` import below resolves.
# _HERE is the directory containing this file; _RECOIL_ROOT is two levels above it.
_HERE = os.path.dirname(os.path.abspath(__file__))
_RECOIL_ROOT = os.path.dirname(os.path.dirname(_HERE))
# Prepend (not append) so this checkout takes precedence on sys.path; the
# membership check avoids inserting a duplicate entry.
if _RECOIL_ROOT not in sys.path:
    sys.path.insert(0, _RECOIL_ROOT)

# This import must come after the sys.path tweak above, hence the E402 suppression.
from recoil.core.paths import ensure_pipeline_importable, projects_root, ProjectPaths  # noqa: E402

# NOTE(review): presumably completes import-path setup for the pipeline packages
# that main() imports lazily (recoil.execution / recoil.pipeline) — confirm in
# recoil.core.paths.
ensure_pipeline_importable()


def _parse_shot_ids(raw):
    """Parse a comma-separated ``--shots`` value into a set of shot IDs.

    Whitespace around each ID is stripped and empty items are dropped, so
    ``"EP001_SH01, EP001_SH05,"`` yields ``{"EP001_SH01", "EP001_SH05"}``.
    """
    return {part.strip() for part in raw.split(",") if part.strip()}


def _shot_prompt(shot, style_prefix=""):
    """Return the text prompt for *shot*, optionally prefixed by *style_prefix*.

    Sources are tried in order until one is non-empty:
    ``prompt_data["compiled_prompt"]``; the truthy values of
    ``prompt_data["prompt_skeleton"]`` joined with ", "; ``shot["description"]``;
    ``shot["action"]``. Both ``--dry-run`` and real generation use this, so the
    dry run previews the same prompt that would be sent.
    """
    prompt_data = shot.get("prompt_data") or {}
    prompt = prompt_data.get("compiled_prompt") or ""
    if not prompt:
        skeleton = prompt_data.get("prompt_skeleton") or {}
        # str() guards against non-string skeleton values, which would make join() raise.
        prompt = ", ".join(str(v) for v in skeleton.values() if v)
    if not prompt:
        # `or` (not a .get default) so an empty description still falls through to action.
        prompt = shot.get("description") or shot.get("action") or ""
    if style_prefix:
        prompt = f"{style_prefix} {prompt}" if prompt else style_prefix
    return prompt


def _build_char_refs(casting, project_dir, resolve):
    """Map upper-cased character IDs to their resolved identity reference paths.

    Accepts both a nested ``{"characters": {...}}`` casting state and a flat
    ``{char_id: {...}}`` one. For each character, ``hero_path`` and then the
    ``front`` and ``three_quarter`` entries of ``turnaround_paths`` are passed
    to ``resolve(project_dir, rel_path)``; falsy results are dropped.
    Non-dict entries are skipped so stray metadata keys cannot crash the build.
    """
    if "characters" in casting:
        entries = casting["characters"] or {}
    else:
        # Flat layout: locations are always nested under "locations", so that key
        # is not a character.
        entries = {k: v for k, v in casting.items() if k != "locations"}

    refs = {}
    for char_name, char_data in entries.items():
        if not isinstance(char_data, dict):
            continue
        turnarounds = char_data.get("turnaround_paths") or {}
        candidates = [char_data.get("hero_path")]
        candidates += [turnarounds.get(view) for view in ("front", "three_quarter")]
        paths = []
        for rel in candidates:
            if rel:
                full = resolve(project_dir, rel)
                if full:
                    paths.append(full)
        # Upper-case keys so shot character IDs match regardless of casing in either file.
        refs[str(char_name).upper()] = paths
    return refs


def _build_loc_refs(casting, project_dir, resolve):
    """Map location IDs to their resolved ``hero_path`` (unresolvable ones are omitted)."""
    refs = {}
    for loc_name, loc_data in (casting.get("locations") or {}).items():
        if not isinstance(loc_data, dict):
            continue
        hero = loc_data.get("hero_path")
        full = resolve(project_dir, hero) if hero else None
        if full:
            refs[loc_name] = full
    return refs


def _identity_refs(asset_data, char_refs):
    """Collect identity reference paths for every character listed in *asset_data*.

    ``asset_data["characters"]`` entries may be plain ID strings or dicts with a
    ``char_id`` key. IDs are upper-cased to match the keys built by
    ``_build_char_refs``. Entries without a usable string ID, and characters with
    no casting refs, contribute nothing instead of raising.
    """
    refs = []
    for entry in asset_data.get("characters") or []:
        cid = entry.get("char_id") if isinstance(entry, dict) else entry
        if isinstance(cid, str) and cid:
            refs.extend(char_refs.get(cid.upper(), []))
    return refs


def main():
    """Generate keyframes for one episode's shots via StepRunner.

    Reads ``ep_<NNN>_plan.json`` from the project's plans directory, optionally
    narrows it to ``--shots``, resolves character and location reference images
    from the casting state, and dispatches one ``image_t2i`` job per shot.
    ``--dry-run`` prints the per-shot plan instead of generating.

    Exits with status 1 if the project or plan is missing, or if any shot fails
    to generate, so shell pipelines / CI can detect partial failure.
    """
    parser = argparse.ArgumentParser(description="Generate keyframes via StepRunner")
    parser.add_argument("--project", required=True)
    parser.add_argument("--episode", type=int, required=True)
    parser.add_argument("--model", default="seedream-v4.5")
    parser.add_argument("--shots", help="Comma-separated shot IDs (default: all)")
    parser.add_argument("--dry-run", action="store_true", help="Print plan without generating")
    parser.add_argument("--no-location-refs", action="store_true", help="Skip location scene refs")
    parser.add_argument("--style-prefix", help="Prepend style anchor to all prompts")
    args = parser.parse_args()

    project_dir = projects_root() / args.project
    if not project_dir.exists():
        print(f"Project not found: {project_dir}")
        sys.exit(1)

    # Load plan
    pp = ProjectPaths.for_project(args.project)
    plan_path = pp.plans_dir / f"ep_{args.episode:03d}_plan.json"
    if not plan_path.exists():
        print(f"Plan not found: {plan_path}")
        sys.exit(1)

    plan = json.loads(plan_path.read_text())
    shots = plan.get("shots") or []

    # Filter shots if specified
    if args.shots:
        target_ids = _parse_shot_ids(args.shots)
        shots = [s for s in shots if s["shot_id"] in target_ids]
        missing = target_ids - {s["shot_id"] for s in shots}
        if missing:
            # Surface typos rather than silently generating fewer shots than requested.
            print(f"Warning: shot IDs not in plan: {', '.join(sorted(missing))}")

    # Style prefix — anchors anime consistency across all shots
    style_prefix = args.style_prefix or ""

    print(f"Project: {args.project}")
    print(f"Episode: {args.episode}")
    print(f"Model: {args.model}")
    print(f"Shots: {len(shots)}")
    if style_prefix:
        print(f"Style prefix: {style_prefix}")

    if not shots:
        print("No shots to generate.")
        return

    # Load casting state for ref paths
    casting_path = pp.casting_state_path
    casting = json.loads(casting_path.read_text()) if casting_path.exists() else {}

    # Build ref lookups (extension-tolerant via resolve_ref_path)
    from recoil.core.paths import resolve_ref_path

    char_refs = _build_char_refs(casting, project_dir, resolve_ref_path)
    loc_refs = _build_loc_refs(casting, project_dir, resolve_ref_path)

    if args.dry_run:
        for shot in shots:
            asset_data = shot.get("asset_data") or {}
            # Preview the shot-specific prompt; the shared style prefix is printed once above.
            prompt = _shot_prompt(shot)
            preview = prompt if len(prompt) <= 60 else prompt[:60] + "..."
            chars = asset_data.get("characters", [])
            loc = asset_data.get("location_id", "")
            n_refs = len(_identity_refs(asset_data, char_refs))
            print(f"  {shot['shot_id']}: {preview} chars={chars} refs={n_refs} loc={loc}")
        return

    # Import StepRunner lazily so --dry-run does not need the execution stack.
    from recoil.execution.step_runner import StepRunner
    from recoil.execution.step_types import ProjectPaths as ExecProjectPaths
    from recoil.execution.execution_store import ExecutionStore
    from recoil.pipeline.core.dispatch import dispatch
    from recoil.pipeline.core.dispatch_context import DispatchContext
    from recoil.pipeline.core.cost import read_cost_from_result

    store = ExecutionStore(args.project)
    paths = ExecProjectPaths.for_episode(args.project, args.episode)
    paths.frames_dir.mkdir(parents=True, exist_ok=True)
    runner = StepRunner(store, paths, episode=args.episode)

    ctx = DispatchContext(
        caller_id="generate_keyframes",
        step_runner=runner,
        project=args.project,
        episode=args.episode,
    )

    total_cost = 0.0
    ok = 0
    t0 = time.monotonic()  # monotonic: elapsed time unaffected by wall-clock changes

    for shot in shots:
        sid = shot["shot_id"]
        prompt = _shot_prompt(shot, style_prefix)

        asset_data = shot.get("asset_data") or {}
        loc_id = asset_data.get("location_id", "")
        identity = _identity_refs(asset_data, char_refs)

        # Scene ref — skip if --no-location-refs
        scene_ref = None if args.no_location_refs else loc_refs.get(loc_id)
        scene_label = "SKIP" if args.no_location_refs else (loc_id or "none")

        print(f"  {sid}: {len(identity)} identity refs, scene={scene_label}")
        receipt = dispatch(
            "image_t2i",
            {
                "shot_id": sid,
                "prompt": prompt,
                "model": args.model,
                "scene_ref_path": scene_ref,
                "identity_refs": identity,
                "aspect_ratio": "9:16",
            },
            context=ctx,
        )
        result = receipt.run_result

        if result.success:
            ok += 1
            cost_usd = read_cost_from_result(result)
            total_cost += cost_usd
            print(f"  [OK] {sid} → {result.output_path} (${cost_usd:.3f})")
        else:
            print(f"  [FAIL] {sid}: {result.error}")

    elapsed = time.monotonic() - t0
    print(f"\nDone: {ok}/{len(shots)} in {elapsed:.0f}s, ${total_cost:.2f} total")
    if ok < len(shots):
        # Non-zero exit lets callers detect that at least one shot failed.
        sys.exit(1)


if __name__ == "__main__":
    # main() signals failure itself via sys.exit(1) and otherwise returns None,
    # which sys.exit() maps to a zero exit status.
    sys.exit(main())
