#!/usr/bin/env python3
"""Corridor sequence v4 — All 4 element slots, shorter prompts.

Fixes from v3:
  - V3 endpoint (not O3) — elements actually work now
  - MIME detection via magic bytes
  - Shorter action-focused prompts (fal.ai best practice)
  - All 4 element slots used

Elements:
  @Element1 = TORCH (hero_beauty.png + 3 turnaround refs)
  @Element2 = SALVAGE_HOOK prop (salvage_hook_hero.png)
  @Element3 = SH05A tunnel reveal frame
  @Element4 = INT_LOWER_DECKS_CORRIDOR (hero ref)

Prompts: Short corridor narrative — empty corridor → hook → prying → tunnel reveal.
"""

import json
import os
import sys
import time

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from pathlib import Path
from lib.constants import PROJECTS_ROOT
from lib.execution_store import ExecutionStore
from lib.elements import ElementManager, extract_batch_location, _to_data_uri
from orchestrator.step_runner import StepRunner
from orchestrator.step_types import ProjectPaths

# Project identity and the shots stitched into this multi-shot sequence.
PROJECT = "tartarus"
SHOT_IDS = ["EP001_SH01", "EP001_SH03", "EP001_SH04", "EP001_SH05"]

# Fixed reference images used as Element 3 (tunnel reveal) and Element 2 (hook).
_PROJECT_DIR = PROJECTS_ROOT / PROJECT
SH05A_HERO = _PROJECT_DIR / "output" / "previs" / "ep_001" / "shot_005a_take3_31434.png"
HOOK_REF = _PROJECT_DIR / "output" / "refs" / "props" / "salvage_hook" / "salvage_hook_hero.png"

# ── Shorter prompts: action-focused, natural @Element weaving ──
# One (duration_seconds, prompt) pair per segment, in playback order.
# Segment indices are 1-based and derived from position in this table.
_SEGMENT_SPECS = (
    (4, "Slow pullback through empty dark corridor @Element4, "
        "dust floating in flickering light, shadows deepening"),
    (3, "The tool from @Element2 swings into frame from the right, "
        "wedges into corroded wall panel, sparks fly @Element4"),
    (4, "@Element1 grips @Element2, prying against buckled metal panel, "
        "straining, panel bending, rust flaking @Element4"),
    (4, "@Element1 wrenches panel free, revealing @Element3 "
        "dark tunnel entrance, faint blue glow, dust billowing @Element4"),
)

CUSTOM_SEQUENCE = [
    {"index": position, "duration": seconds, "prompt": text}
    for position, (seconds, text) in enumerate(_SEGMENT_SPECS, start=1)
]

# Load the episode plan (needed for batch metadata / shot IDs) and build the
# batch of shot dicts in SHOT_IDS order.
plan_path = PROJECTS_ROOT / PROJECT / "state/visual/plans/ep_001_plan.json"
if not plan_path.exists():
    print(f"ERROR: Plan not found at {plan_path}")
    sys.exit(1)

# JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
plan = json.loads(plan_path.read_text(encoding="utf-8"))
shots_map = {s["shot_id"]: s for s in plan["shots"]}

# Report every missing shot at once instead of dying on the first KeyError.
missing_shots = [sid for sid in SHOT_IDS if sid not in shots_map]
if missing_shots:
    print(f"ERROR: Shots missing from plan {plan_path}: {', '.join(missing_shots)}")
    sys.exit(1)

batch = [shots_map[sid] for sid in SHOT_IDS]

# Build Element 1: TORCH (with turnaround refs).
# location_id=None: the corridor location is appended manually as Element 4
# further down, so only the character element is requested here.
char_ids = ["TORCH"]
payload, _has_location, _n_standard_elements = ElementManager.build_elements_with_info(
    char_ids, PROJECT, location_id=None,
)
char_elements = payload.setdefault("elements", [])

print("Element 1: TORCH")
for elem in char_elements:
    frontal = elem.get("frontal_image_url", "")
    head = frontal[:50]  # enough characters to include the data-URI "scheme:mime;" header
    mime = head.split(";")[0].split(":")[1] if ":" in head else "?"
    n_refs = len(elem.get("reference_image_urls", []))
    # Base64 stores 3 bytes per 4 chars, so this approximates the decoded size.
    frontal_size = len(frontal) * 3 // 4 // 1024
    print(f"  MIME={mime}, {n_refs} refs, frontal={frontal_size}KB")

# The prompts address elements by position (@Element1 = TORCH, @Element2 =
# hook, ...). Anything other than exactly one element per character would
# shift every later @ElementN onto the wrong image, so stop here instead.
if len(char_elements) != len(char_ids):
    print(
        f"ERROR: Expected {len(char_ids)} character element(s) for "
        f"{', '.join(char_ids)}, got {len(char_elements)}"
    )
    sys.exit(1)

# Element 2: SALVAGE_HOOK prop — a single frontal image with no extra refs.
# A missing or unencodable image aborts the run.
if HOOK_REF.exists():
    hook_uri = _to_data_uri(HOOK_REF)
else:
    print(f"ERROR: Salvage hook ref not found at {HOOK_REF}")
    sys.exit(1)

if not hook_uri:
    print("ERROR: Failed to encode salvage hook")
    sys.exit(1)

hook_element = {"frontal_image_url": hook_uri, "reference_image_urls": []}
payload["elements"].append(hook_element)

hook_mime = hook_uri.split(";", 1)[0].split(":")[1]
hook_size = (len(hook_uri) * 3 // 4) // 1024  # approx. decoded KB of the base64 payload
print(f"Element 2: SALVAGE_HOOK MIME={hook_mime}, frontal={hook_size}KB")

# Element 3: SH05A tunnel-reveal frame, attached as a single frontal image.
# Both a missing file and a failed encode abort the run.
if not SH05A_HERO.exists():
    print(f"ERROR: SH05A hero not found at {SH05A_HERO}")
    sys.exit(1)

sh05a_uri = _to_data_uri(SH05A_HERO)
if not sh05a_uri:
    print("ERROR: Failed to encode SH05A hero")
    sys.exit(1)

sh05a_element = {
    "frontal_image_url": sh05a_uri,
    "reference_image_urls": [],
}
payload["elements"].append(sh05a_element)

sh05a_header = sh05a_uri.partition(";")[0]
sh05a_mime = sh05a_header.split(":")[1]
sh05a_size = len(sh05a_uri) * 3 // 4096  # same as *3//4//1024: approx. decoded KB
print(f"Element 3: SH05A reveal MIME={sh05a_mime}, frontal={sh05a_size}KB")

# Element 4: Corridor location (hero only — smaller file to avoid 422).
# The location id comes from the batch. Prefer a ref whose filename contains
# "hero"; otherwise fall back to the smallest ref file on disk. Missing
# location art is non-fatal: the run continues without this element.
from lib.recoil_bridge import get_location_refs

location_id = extract_batch_location(batch)
if location_id:
    # Materialize: the refs are scanned by next() and possibly again by min(),
    # which would silently see nothing if a one-shot iterator were returned.
    loc_refs = list(get_location_refs(location_id, project=PROJECT) or [])
else:
    loc_refs = []

loc_hero = next((r for r in loc_refs if "hero" in r.name.lower()), None)
if loc_hero is None and loc_refs:
    loc_hero = min(loc_refs, key=lambda p: p.stat().st_size)

loc_uri = _to_data_uri(loc_hero) if loc_hero is not None else None
if loc_uri:
    payload["elements"].append({
        "frontal_image_url": loc_uri,
        "reference_image_urls": [],
    })
    loc_mime = loc_uri.split(";")[0].split(":")[1]
    loc_size = len(loc_uri) * 3 // 4 // 1024  # approx. decoded KB
    print(f"Element 4: {location_id} ({loc_hero.name}) MIME={loc_mime}, frontal={loc_size}KB")
elif loc_hero is not None:
    print("WARNING: Failed to encode location ref, continuing without it")
else:
    print(f"WARNING: No location refs found for {location_id}, continuing without it")

# Sanity-check the assembled payload against the prompts before submitting.
import re

MAX_ELEMENTS = 4  # element slot limit this script targets
total_elements = len(payload["elements"])
print(f"\nElements: {total_elements} total (max {MAX_ELEMENTS})")
if total_elements > MAX_ELEMENTS:
    print(f"ERROR: {total_elements} elements exceeds the limit of {MAX_ELEMENTS}")
    sys.exit(1)

# @ElementN tags refer to 1-based positions in payload["elements"]. Warn about
# tags with no attached element (e.g. the optional location was skipped).
referenced = sorted({
    int(num)
    for seg in CUSTOM_SEQUENCE
    for num in re.findall(r"@Element(\d+)", seg["prompt"])
})
dangling = [f"@Element{num}" for num in referenced if not 1 <= num <= total_elements]
if dangling:
    print(
        f"WARNING: Prompts reference {', '.join(dangling)} but only "
        f"{total_elements} element(s) are attached"
    )

# Print sequence
print(f"Segments: {len(CUSTOM_SEQUENCE)}, total: {sum(s['duration'] for s in CUSTOM_SEQUENCE)}s")
for seg in CUSTOM_SEQUENCE:
    print(f"\n  Seg {seg['index']} ({seg['duration']}s): {seg['prompt']}")

# Find the start frame for SH01.
# 1) Prefer the hero frame recorded in the execution store's gate results.
# 2) Otherwise use the most recently modified *image* in the episode previs
#    folder whose name contains "shot_001". Non-image files (videos, JSON
#    sidecars, directories) must never be chosen as a start frame.
_IMAGE_SUFFIXES = frozenset({".png", ".jpg", ".jpeg", ".webp"})

store = ExecutionStore(PROJECT)
start_frame = None
sh01_state = store.get_shot("EP001_SH01")
if sh01_state:
    # `or {}` also covers a stored gate_results of None.
    hero = (sh01_state.get("gate_results") or {}).get("hero_frame")
    if hero:
        p = PROJECTS_ROOT / PROJECT / hero
        if p.exists():
            start_frame = p

if not start_frame:
    frames_dir = PROJECTS_ROOT / PROJECT / "output" / "previs" / "ep_001"
    candidates = [
        c for c in frames_dir.glob("*shot_001*")
        if c.is_file() and c.suffix.lower() in _IMAGE_SUFFIXES
    ]
    # Newest first; max() keeps the first of equal mtimes, like a stable reverse sort.
    start_frame = max(candidates, key=lambda c: c.stat().st_mtime, default=None)

print(f"\nStart frame: {start_frame.name if start_frame else 'NONE'}")

# ── Execute the multi-shot request and report per-shot results ──
# NOTE(review): the module docstring says this version moved to the V3
# endpoint "(not O3)", yet the model key passed below is "kling-o3". Confirm
# how StepRunner.execute_multi_shot resolves this key before relying on it.
paths = ProjectPaths.for_episode(PROJECT, 1)
runner = StepRunner(store=store, paths=paths)

started_at = time.time()
results = runner.execute_multi_shot(
    batch=batch,
    multi_prompt_sequence=CUSTOM_SEQUENCE,
    model="kling-o3",
    start_frame=start_frame,
    elements_payload=payload,
)
elapsed = time.time() - started_at

print()
for result in results:
    label = "OK" if result.success else "FAIL"
    print(f"  [{label}] {result.shot_id} -> {result.output_path} (${result.cost_usd:.2f})")
    if result.error:
        print(f"    Error: {result.error}")

total_cost = sum(result.cost_usd for result in results)
print(f"\nDone in {elapsed:.0f}s, ${total_cost:.2f} total")
