#!/usr/bin/env python3
"""Re-run 4-segment corridor sequence with fixed TORCH hero + SH05A reveal element.

Elements:
  @Element1 = TORCH (hero_beauty.png + turnarounds)
  @Element2 = INT_LOWER_DECKS_CORRIDOR (location ref)
  @Element3 = SH05A hero frame (tunnel entrance reveal)
"""

import json
import os
import sys
import time

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from pathlib import Path
from lib.constants import PROJECTS_ROOT
from lib.execution_store import ExecutionStore
from lib.elements import ElementManager, extract_batch_location, _to_data_uri
from lib.prompt_engine import build_multi_prompt_sequence
from orchestrator.step_runner import StepRunner
from orchestrator.step_types import ProjectPaths

# Project directory name under PROJECTS_ROOT, and the shots (in order) that
# make up the corridor sequence.
PROJECT = "tartarus"
SHOT_IDS = ["EP001_SH01", "EP001_SH03", "EP001_SH04", "EP001_SH05"]
# Existing SH05A previs frame; attached later as an extra reference element.
SH05A_HERO = PROJECTS_ROOT / PROJECT / "output/previs/ep_001/shot_005a_take3_31434.png"

# Load plan and index its shots by shot_id.
plan_path = PROJECTS_ROOT / PROJECT / "state/visual/plans/ep_001_plan.json"
# JSON is UTF-8 by specification; do not depend on the locale's default encoding.
plan = json.loads(plan_path.read_text(encoding="utf-8"))
shots_map = {s["shot_id"]: s for s in plan["shots"]}

# Report every missing shot at once with a readable message instead of
# dying on the first one with a bare KeyError traceback.
missing_ids = [sid for sid in SHOT_IDS if sid not in shots_map]
if missing_ids:
    print(f"ERROR: shots not found in {plan_path}: {', '.join(missing_ids)}")
    sys.exit(1)

# Collect the shots in SHOT_IDS order, stamping each with the duration to
# request: the plan's target editorial duration (5 when unspecified),
# never less than 3.
batch = []
for sid in SHOT_IDS:
    shot = shots_map[sid]
    # Either level may be absent *or* explicitly null in the plan JSON;
    # both cases fall back to the default of 5.
    target_s = (shot.get("routing_data") or {}).get("target_editorial_duration_s")
    shot["_api_duration"] = max(3, 5 if target_s is None else target_s)
    batch.append(shot)

# Standard elements come from ElementManager: the TORCH character plus,
# when one is available for the batch, the shared location.
char_ids = ["TORCH"]
location_id = extract_batch_location(batch)
payload, has_location, n_standard_elements = ElementManager.build_elements_with_info(
    char_ids,
    PROJECT,
    location_id=location_id,
)

print(f"Location: {location_id}")

# The SH05A hero frame (tunnel entrance reveal) is appended after the
# standard elements; abort if the file is missing or cannot be encoded.
if not SH05A_HERO.exists():
    print(f"ERROR: SH05A hero not found at {SH05A_HERO}")
    sys.exit(1)

sh05a_uri = _to_data_uri(SH05A_HERO)
if not sh05a_uri:
    print("ERROR: Failed to encode SH05A hero")
    sys.exit(1)

reveal_element = {
    "frontal_image_url": sh05a_uri,
    "reference_image_urls": [],
}
payload["elements"].append(reveal_element)
total_elements = len(payload["elements"])

# Summarise the payload, one line per element, numbered from 1.
print(f"Elements: {total_elements} total, location={has_location}")
for number, element in enumerate(payload["elements"], start=1):
    # Labels assume the character comes first, then the location (if any);
    # everything after that is reported as the reveal frame.
    if number == 1:
        label = "char/TORCH"
    elif number == 2 and has_location:
        label = f"location/{location_id}"
    else:
        label = "SH05A reveal"
    frontal_status = "MISSING" if not element.get("frontal_image_url") else "OK"
    extra_refs = element.get("reference_image_urls", [])
    print(f"  Element {number}: {label} frontal={frontal_status}, {len(extra_refs)} additional refs")

# Build sequence — the prompt builder is told only about the standard
# elements (n_standard_elements), so it does not reference the extra
# reveal element on its own.
sequence = build_multi_prompt_sequence(
    batch,
    batch_char_ids=char_ids,
    has_location_element=has_location,
    total_elements=n_standard_elements,
)

# An empty sequence would otherwise surface as an IndexError on
# sequence[-1]; fail the same way the other pre-flight checks do.
if not sequence:
    print("ERROR: prompt sequence is empty, nothing to render")
    sys.exit(1)

# Tag the last segment (panel reveal → tunnel entrance) with the reveal
# element. It was appended last to the payload, so its 1-based number is
# total_elements (@Element3 in the expected TORCH + corridor + reveal layout).
last_seg = sequence[-1]
last_seg["prompt"] = last_seg["prompt"].rstrip() + f" @Element{total_elements}"

# Echo the final prompts and durations for inspection.
print(f"Segments: {len(sequence)}, total: {sum(s['duration'] for s in sequence)}s")
for i, seg in enumerate(sequence):
    print(f"\n  Seg {i+1} ({seg['duration']}s): {seg['prompt']}")

# Resolve the start frame for SH01: prefer the hero frame recorded in the
# execution store, otherwise fall back to the newest matching previs file.
store = ExecutionStore(PROJECT)
project_dir = PROJECTS_ROOT / PROJECT
start_frame = None

sh01_state = store.get_shot("EP001_SH01")
recorded_hero = sh01_state.get("gate_results", {}).get("hero_frame") if sh01_state else None
if recorded_hero:
    recorded_path = project_dir / recorded_hero
    # Only trust the stored path if the file is still on disk.
    if recorded_path.exists():
        start_frame = recorded_path

if not start_frame:
    previs_dir = project_dir / "output" / "previs" / "ep_001"
    # Most recently modified match wins; None when nothing matches.
    start_frame = max(
        previs_dir.glob("*shot_001*"),
        key=lambda frame: frame.stat().st_mtime,
        default=None,
    )

start_label = start_frame.name if start_frame else "NONE"
print(f"\nStart frame: {start_label}")

# Execute the whole batch as one multi-shot request for episode 1.
paths = ProjectPaths.for_episode(PROJECT, 1)
runner = StepRunner(store=store, paths=paths)

# perf_counter is monotonic, so the elapsed figure cannot be skewed by
# system clock adjustments during a long render.
t0 = time.perf_counter()
results = runner.execute_multi_shot(
    batch=batch,
    multi_prompt_sequence=sequence,
    model="kling-o3",
    start_frame=start_frame,
    elements_payload=payload,
)
elapsed = time.perf_counter() - t0

# Per-shot report: status, output location, cost, and any error text.
print()
for r in results:
    status = "OK" if r.success else "FAIL"
    print(f"  [{status}] {r.shot_id} -> {r.output_path} (${r.cost_usd:.2f})")
    if r.error:
        print(f"    Error: {r.error}")
print(f"\nDone in {elapsed:.0f}s, ${sum(r.cost_usd for r in results):.2f} total")

# Signal failure to the calling shell, matching the exit status used by
# the pre-flight checks; the full report above is still printed first.
if any(not r.success for r in results):
    sys.exit(1)
