#!/usr/bin/env python3
"""Corridor sequence v3 — MIME fix + original hand-written prompts + 3 elements.

Root cause: _to_data_uri() was sending JPEG bytes with image/png MIME type.
fal.ai silently dropped the malformed elements. Now fixed via magic-byte detection.

Elements:
  @Element1 = TORCH (hero_beauty.png — now correctly sent as image/jpeg)
  @Element2 = INT_LOWER_DECKS_CORRIDOR (location ref)
  @Element3 = SH05A hero frame (tunnel entrance reveal)

Prompts: Original hand-written corridor narrative (not auto-generated from plan data).
"""

import json
import os
import sys
import time

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from pathlib import Path
from lib.constants import PROJECTS_ROOT
from lib.execution_store import ExecutionStore
from lib.elements import ElementManager, extract_batch_location, _to_data_uri
from orchestrator.step_runner import StepRunner
from orchestrator.step_types import ProjectPaths

# Project key and the plan shots rendered together as one continuous take (in order).
PROJECT = "tartarus"
SHOT_IDS = [
    "EP001_SH01",
    "EP001_SH03",
    "EP001_SH04",
    "EP001_SH05",
]
# Hero still for the tunnel-entrance reveal; attached further down as @Element3.
SH05A_HERO = PROJECTS_ROOT / PROJECT / "output" / "previs" / "ep_001" / "shot_005a_take3_31434.png"

# ── Hand-written prompts: one continuous corridor take ──
# Beats: empty corridor → hook enters → panel pried loose → tunnel revealed.
# Each pair below is (duration in seconds, prompt); "index" is assigned
# 1-based in listing order. @ElementN tokens address the elements payload
# by position (1 = TORCH, 2 = corridor location, 3 = SH05A hero frame).
CUSTOM_SEQUENCE = [
    {"index": position, "duration": seconds, "prompt": prompt}
    for position, (seconds, prompt) in enumerate(
        (
            (
                4,
                "Slow pullback through empty dark industrial corridor @Element2, "
                "handheld drift, dust particles floating in dim flickering light, "
                "shadows deepening",
            ),
            (
                3,
                "CLANG — salvage hook swings into frame from right side, "
                "wedges hard into corroded wall panel, impact sparks fly, "
                "metal scraping metal @Element2",
            ),
            (
                4,
                "@Element1 grimy hands gripping salvage hook handle, "
                "prying hard against buckled metal panel, straining muscles, "
                "panel groaning and bending, rust flaking off @Element2",
            ),
            (
                4,
                "@Element1 panel wrenches free and slides down with a metallic screech, "
                "revealing @Element3 dark tunnel entrance beyond, faint blue glow from within, "
                "dust billowing out @Element2",
            ),
        ),
        start=1,
    )
]

# Load the episode plan: it supplies the per-shot metadata for the batch.
plan_path = PROJECTS_ROOT / PROJECT / "state/visual/plans/ep_001_plan.json"
if not plan_path.exists():
    print(f"ERROR: plan not found at {plan_path}")
    sys.exit(1)

# JSON is UTF-8; don't depend on the platform's locale default encoding.
plan = json.loads(plan_path.read_text(encoding="utf-8"))
shots_map = {s["shot_id"]: s for s in plan["shots"]}

# Fail with a readable message (instead of a bare KeyError) if the plan
# doesn't contain every requested shot.
missing_ids = [sid for sid in SHOT_IDS if sid not in shots_map]
if missing_ids:
    print(f"ERROR: shots missing from plan {plan_path}: {', '.join(missing_ids)}")
    sys.exit(1)

# Batch is kept in SHOT_IDS order.
batch = [shots_map[sid] for sid in SHOT_IDS]

# Build the standard elements: TORCH character (expected @Element1) plus the
# batch's location reference (expected @Element2, when a location is found).
char_ids = ["TORCH"]
location_id = extract_batch_location(batch)
payload, has_location, n_standard_elements = ElementManager.build_elements_with_info(
    char_ids, PROJECT, location_id=location_id,
)

print(f"Location: {location_id}")

# Verify each element's frontal image carries the right MIME type
# (mislabeled MIME types previously made fal.ai drop elements).
for num, elem in enumerate(payload.get("elements", []), start=1):
    frontal_url = elem.get("frontal_image_url", "")
    # "data:image/jpeg;base64,..." -> "image/jpeg"; also handles "data:image/png,...".
    # Non-data URLs carry no inline MIME type, so report "?".
    if frontal_url.startswith("data:"):
        mime = frontal_url[len("data:"):].split(",", 1)[0].split(";", 1)[0] or "?"
    else:
        mime = "?"
    # Assumes one element per character in char_ids order, followed by the
    # location element — NOTE(review): confirm against ElementManager.
    if num <= len(char_ids):
        label = f"char/{char_ids[num - 1]}"
    elif has_location:
        label = f"location/{location_id}"
    else:
        label = "unknown"
    print(f"  Element {num}: {label} MIME={mime}")

# Add the SH05A hero frame as Element3 (tunnel-entrance reveal).
if not SH05A_HERO.exists():
    print(f"ERROR: SH05A hero not found at {SH05A_HERO}")
    sys.exit(1)

sh05a_uri = _to_data_uri(SH05A_HERO)
if not sh05a_uri:
    print("ERROR: Failed to encode SH05A hero")
    sys.exit(1)

# The payload's element list may be absent (it is read with .get above).
elements = payload.setdefault("elements", [])

# The prompts reference elements by position (@Element1..@Element3), so the
# hero only lands at @Element3 if exactly TORCH + the location precede it.
# Abort before the paid generation rather than send misaligned references.
if not has_location or len(elements) != 2:
    print(
        "ERROR: expected 2 standard elements (TORCH + location) before SH05A, "
        f"got {len(elements)} (has_location={has_location}); "
        "@Element references in CUSTOM_SEQUENCE would be misaligned"
    )
    sys.exit(1)

elements.append({
    "frontal_image_url": sh05a_uri,
    "reference_image_urls": [],
})
total_elements = len(elements)

# "data:image/jpeg;base64,..." -> "image/jpeg" (tolerates a missing ";base64").
if sh05a_uri.startswith("data:"):
    sh05a_mime = sh05a_uri[len("data:"):].split(",", 1)[0].split(";", 1)[0] or "?"
else:
    sh05a_mime = "?"
print(f"  Element {total_elements}: SH05A reveal MIME={sh05a_mime}")
print(f"Elements: {total_elements} total")

# Echo the hand-written sequence (count, total seconds, each prompt) before execution.
total_seconds = sum(segment["duration"] for segment in CUSTOM_SEQUENCE)
print(f"\nSegments: {len(CUSTOM_SEQUENCE)}, total: {total_seconds}s")
for segment in CUSTOM_SEQUENCE:
    seg_index, seg_seconds, seg_prompt = segment["index"], segment["duration"], segment["prompt"]
    print(f"\n  Seg {seg_index} ({seg_seconds}s): {seg_prompt}")

# Find the start frame for SH01: prefer the hero frame recorded in the
# execution store; otherwise fall back to the newest matching image on disk.
store = ExecutionStore(PROJECT)
start_frame = None
sh01_state = store.get_shot("EP001_SH01")
if sh01_state:
    # `or {}` also covers gate_results stored as None, which
    # .get("gate_results", {}) would return as None and then crash on.
    hero = (sh01_state.get("gate_results") or {}).get("hero_frame")
    if hero:
        hero_path = PROJECTS_ROOT / PROJECT / hero
        if hero_path.exists():
            start_frame = hero_path

if not start_frame:
    frames_dir = PROJECTS_ROOT / PROJECT / "output" / "previs" / "ep_001"
    # The start frame must be a still image; the glob alone could also match
    # non-image files or directories whose names contain "shot_001".
    image_suffixes = {".png", ".jpg", ".jpeg", ".webp"}
    candidates = sorted(
        (
            f
            for f in frames_dir.glob("*shot_001*")
            if f.is_file() and f.suffix.lower() in image_suffixes
        ),
        key=lambda f: f.stat().st_mtime,
        reverse=True,
    )
    if candidates:
        start_frame = candidates[0]

print(f"\nStart frame: {start_frame.name if start_frame else 'NONE'}")

# ── Execute: render the whole batch as one multi-shot generation, then report ──
paths = ProjectPaths.for_episode(PROJECT, 1)
runner = StepRunner(store=store, paths=paths)

started = time.time()
results = runner.execute_multi_shot(
    batch=batch,
    multi_prompt_sequence=CUSTOM_SEQUENCE,
    model="kling-o3",
    start_frame=start_frame,
    elements_payload=payload,
)
elapsed = time.time() - started

# One line per shot (plus its error, if any), then a wall-clock / cost summary.
print()
for outcome in results:
    verdict = "FAIL" if not outcome.success else "OK"
    print(f"  [{verdict}] {outcome.shot_id} -> {outcome.output_path} (${outcome.cost_usd:.2f})")
    if outcome.error:
        print(f"    Error: {outcome.error}")
total_cost = sum(outcome.cost_usd for outcome in results)
print(f"\nDone in {elapsed:.0f}s, ${total_cost:.2f} total")
