#!/usr/bin/env python3
"""
test_steprunner_gate2.py — First real test of unified StepRunner + Gate 2.

Regenerates EP001_SH02 keyframe through StepRunner.execute_keyframe()
with Gate 2 (identity check) wired in. Retries up to 3 times on identity
failure.

Expected cost: ~$0.134/attempt (Pro keyframe) + ~$0.039/attempt (Flash gate)
             = ~$0.173/attempt, max ~$0.69 for 4 attempts

Usage:
    python3 tools/test_steprunner_gate2.py
    python3 tools/test_steprunner_gate2.py --dry-run    # Print plan only
    python3 tools/test_steprunner_gate2.py --shot EP001_SH04  # Different shot
"""

import argparse
import json
import logging
import sys
import time
from pathlib import Path

# Ensure starsend root (the parent of this script's directory) is on the path.
# Resolve first so the inserted entry is absolute: a relative ``__file__``
# would otherwise produce a sys.path entry that depends on the current
# working directory.
STARSEND_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(STARSEND_ROOT))

# Root logging: INFO and above, time-only timestamps, logger name in each line.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
_LOG_DATE_FORMAT = "%H:%M:%S"

logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
    level=logging.INFO,
)

# Script-wide logger used by main().
logger = logging.getLogger("test_gate2")

from lib.constants import PROJECTS_ROOT
from lib.execution_store import ExecutionStore
from lib.recoil_bridge import load_storyboard, get_character_refs
from lib.asset_manager import AssetManager
from lib.prompt_engine import build_cinematic_prompt
from lib.model_profiles import get_model
from orchestrator.step_types import ProjectPaths
from orchestrator.step_runner import StepRunner, make_identity_gate


# Per-attempt cost estimates in USD; used only for log output.
_KEYFRAME_COST_USD = 0.134  # 1 Pro keyframe generation
_GATE_COST_USD = 0.039      # 1 Flash gate check
# Retries allowed after the first attempt; total attempts = retries + 1.
_MAX_GATE_RETRIES = 3


def _parse_args(argv=None):
    """Parse the command line and return a namespace with ``shot`` and ``dry_run``.

    ``argv`` defaults to ``sys.argv[1:]``. Unknown flags and a ``--shot``
    without a value terminate the process with exit status 2 instead of
    silently falling through to a real (billed) run.
    """
    parser = argparse.ArgumentParser(
        description="Regenerate one keyframe through StepRunner with Gate 2.",
    )
    parser.add_argument(
        "--shot",
        default="EP001_SH02",
        metavar="SHOT_ID",
        help="shot to regenerate (default: %(default)s)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="print the plan only; do not generate anything",
    )
    return parser.parse_args(argv)


def _find_shot(shots, shot_id):
    """Return the first shot whose ``shot_id`` equals *shot_id*, else ``None``."""
    return next((s for s in shots if s.get("shot_id") == shot_id), None)


def _resolve_scene_ref(shot, project):
    """Return the shot's location reference path, or ``None`` if unavailable.

    The path is ``<PROJECTS_ROOT>/<project>/output/refs/locations/<location_id
    lower-cased>/<location_view_id>``. ``None`` is returned when the shot has
    no ``location_view_id``, no ``asset_data.location_id``, or the path does
    not exist on disk (the last case is logged as a warning).
    """
    location_view_id = shot.get("location_view_id")
    if not location_view_id:
        return None

    loc_id = (shot.get("asset_data") or {}).get("location_id", "")
    if not loc_id:
        return None

    loc_dir = PROJECTS_ROOT / project / "output" / "refs" / "locations"
    candidate = loc_dir / loc_id.lower() / location_view_id
    if candidate.exists():
        logger.info("Location ref: %s", candidate.name)
        return candidate

    logger.warning("Location ref not found: %s", candidate)
    return None


def _log_result(result, elapsed):
    """Log the outcome of ``execute_keyframe`` and the wall-clock time taken."""
    logger.info("")
    logger.info("=" * 60)
    logger.info("RESULT")
    logger.info("=" * 60)
    logger.info("  success: %s", result.success)
    logger.info("  final_state: %s", result.final_state)
    logger.info("  output_path: %s", result.output_path)
    logger.info("  cost: $%.3f", result.cost_usd)
    logger.info("  model: %s", result.model)
    logger.info("  elapsed: %.1fs", elapsed)

    verdict = result.gate_verdict
    if verdict:
        logger.info("  gate: %s", verdict.gate_name)
        logger.info("  gate passed: %s", verdict.passed)
        logger.info("  gate reason: %s", verdict.reason)

    logger.info("")
    if result.success:
        logger.info("SUCCESS — keyframe passed Gate 2 identity check!")
        logger.info("View: %s", result.output_path)
    else:
        # A failed run is an error outcome, so surface it at ERROR level.
        logger.error("FAILED — %s", result.error)
        logger.error("Shot left in state: %s", result.final_state)


def main(argv=None) -> int:
    """Regenerate one shot's keyframe through StepRunner with Gate 2 wired in.

    Loads the episode-1 storyboard of the ``starsend-test`` project, gathers
    the character, expression and location references for the requested shot,
    builds the prompt and the identity gate, then calls
    ``StepRunner.execute_keyframe``. With ``--dry-run`` the plan is logged and
    the function returns before the execution store is touched.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success or dry run, 1 if the shot is not in
        the plan or the keyframe run did not succeed.

    Raises:
        SystemExit: With status 2 when the command line is invalid.
    """
    args = _parse_args(argv)
    dry_run = args.dry_run
    shot_id = args.shot

    project = "starsend-test"
    episode = 1

    logger.info("=" * 60)
    logger.info("StepRunner + Gate 2 Test")
    logger.info("  Shot: %s", shot_id)
    logger.info("  Project: %s", project)
    logger.info("  Dry run: %s", dry_run)
    logger.info("=" * 60)

    # ── 1. Load plan and find shot ────────────────────────────────────
    data = load_storyboard(episode, project)
    shot = _find_shot(data["shots"], shot_id)
    if shot is None:
        logger.error("Shot %s not found in plan", shot_id)
        return 1

    # ``or`` guards against explicit nulls in the plan, not just missing keys.
    chars_in_shot = shot.get("characters_in_shot") or []
    prompt_data = shot.get("prompt_data") or {}
    # A shot with no characters is treated as an environment shot.
    is_env = not chars_in_shot

    logger.info("Shot found: type=%s, chars=%s, is_env=%s",
                prompt_data.get("shot_type", "?"),
                chars_in_shot, is_env)

    # ── 2. Resolve character refs ─────────────────────────────────────
    ref_paths = []
    for char_key in chars_in_shot:
        char_refs = get_character_refs(char_key, project)
        ref_paths.extend(char_refs)
        logger.info("Character refs for %s: %d files", char_key, len(char_refs))
        for ref in char_refs:
            logger.info("  - %s", ref.name)

    # ── 3. Resolve expression ref ─────────────────────────────────────
    assets = AssetManager()
    emotion = shot.get("emotion", "")
    expression_ref = None
    if emotion and not is_env:
        expression_ref = assets.get_expression_ref(emotion)
        if expression_ref:
            logger.info("Expression ref: %s (emotion: %s)",
                        expression_ref.path.name, emotion)
        else:
            logger.warning("No expression ref for emotion: %s", emotion)

    # ── 4. Resolve location ref ───────────────────────────────────────
    scene_ref_path = _resolve_scene_ref(shot, project)

    # ── 5. Build prompt ───────────────────────────────────────────────
    prompt = build_cinematic_prompt(
        shot=shot,
        storyboard=data,
        is_env=is_env,
    )
    logger.info("Prompt length: %d chars", len(prompt))
    logger.info("Prompt preview: %s...", prompt[:200])

    # ── 6. Build Gate 2 ──────────────────────────────────────────────
    gate = make_identity_gate(
        ref_paths=ref_paths,
        prompt_skeleton=prompt_data.get("prompt_skeleton"),
    )
    logger.info("Gate 2 wired: %d identity refs", len(ref_paths))

    # ── 7. Model selection ────────────────────────────────────────────
    model = get_model("production", "image")
    logger.info("Model: %s (~$%.3f/call)", model, _KEYFRAME_COST_USD)

    # ── Cost estimate ─────────────────────────────────────────────────
    est_cost = _KEYFRAME_COST_USD + _GATE_COST_USD  # 1 Pro gen + 1 Flash gate
    est_max = est_cost * (_MAX_GATE_RETRIES + 1)    # first attempt + retries
    logger.info("Estimated cost: $%.3f per attempt, max $%.2f", est_cost, est_max)

    if dry_run:
        logger.info("DRY RUN — would execute_keyframe with:")
        logger.info("  shot_id: %s", shot_id)
        logger.info("  model: %s", model)
        logger.info("  identity_refs: %d", len(ref_paths))
        logger.info("  expression_refs: %d", 1 if expression_ref else 0)
        logger.info("  scene_ref: %s", "yes" if scene_ref_path else "no")
        logger.info("  gates: [gate_2_identity]")
        logger.info("  max_gate_retries: %d", _MAX_GATE_RETRIES)
        return 0

    # ── 8. Initialize StepRunner ──────────────────────────────────────
    store = ExecutionStore(project=project)
    project_paths = ProjectPaths.for_episode(project, episode)
    runner = StepRunner(store=store, paths=project_paths)

    # Force-reset shot state to keyframe_pending
    current_status = (store.get_shot(shot_id) or {}).get("status", "unknown")
    logger.info("Current shot status: %s", current_status)
    store.force_reset_status(shot_id, "keyframe_pending",
                            reason="Gate 2 test — regenerating keyframe with identity check")
    logger.info("Force-reset to keyframe_pending")

    # ── 9. Execute! ───────────────────────────────────────────────────
    start_time = time.time()
    result = runner.execute_keyframe(
        shot_id=shot_id,
        prompt=prompt,
        model=model,
        scene_ref_path=scene_ref_path,
        identity_refs=ref_paths,
        expression_refs=[expression_ref.path] if expression_ref else None,
        aspect_ratio="9:16",
        gates=[gate],
        max_gate_retries=_MAX_GATE_RETRIES,
    )
    elapsed = time.time() - start_time

    # ── 10. Report ────────────────────────────────────────────────────
    _log_result(result, elapsed)

    return 0 if result.success else 1


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    raise SystemExit(main())
