#!/usr/bin/env python3
"""
batch_gate2_test.py — Regenerate all EP001 character shots through StepRunner + Gate 2.

Runs sequentially (Gemini rate limits). Reports per-shot results with retry counts.

Usage:
    python3 tools/batch_gate2_test.py
    python3 tools/batch_gate2_test.py --dry-run
"""

import json
import logging
import sys
import time
from pathlib import Path
from typing import Optional

# Directory two levels above this file (the parent of this script's folder).
PIPELINE_ROOT = Path(__file__).parent.parent
# Parent directory of PIPELINE_ROOT.
RECOIL_ROOT = PIPELINE_ROOT.parent
# Bootstrap RECOIL_ROOT first so `from core.X` resolves to recoil/core/
# (top-level), not recoil/pipeline/core/. Then call ensure_pipeline_importable()
# to add PIPELINE_ROOT in the right order.
# NOTE(review): the `recoil.*` imports in this file additionally require the
# directory that *contains* the `recoil` package to be importable; this block
# only inserts RECOIL_ROOT itself — confirm that is provided by the
# environment (install / working directory).
if str(RECOIL_ROOT) not in sys.path:
    # Prepend (index 0) so this entry takes precedence over existing ones.
    sys.path.insert(0, str(RECOIL_ROOT))
from recoil.core.paths import ensure_pipeline_importable
ensure_pipeline_importable()

# Configure the root logger: INFO level, "HH:MM:SS [LEVEL] name: message".
# NOTE(review): this runs before the remaining project imports, presumably so
# any import-time log output already uses this format — confirm.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%H:%M:%S",
)
# Logger used by every function in this script.
logger = logging.getLogger("batch_gate2")

from recoil.core.paths import projects_root, STATE_NAMESPACE
from recoil.core.paths import ProjectPaths as CoreProjectPaths
from recoil.execution.execution_store import ExecutionStore
from recoil.pipeline._lib.recoil_bridge import load_storyboard, get_character_refs
from recoil.execution.asset_manager import AssetManager
from recoil.pipeline._lib.prompt_engine import build_cinematic_prompt
from recoil.core.model_profiles import get_model
from recoil.execution.step_types import ProjectPaths
from recoil.execution.step_runner import StepRunner, make_identity_gate
from recoil.pipeline.core.dispatch import dispatch
from recoil.pipeline.core.dispatch_context import DispatchContext
from recoil.pipeline.core.cost import read_cost_from_result


def build_gate_wardrobe_spec(bible: dict, char_key: str) -> str:
    """Build a newline-joined wardrobe + appearance spec for one character.

    The spec has up to two parts:

    * ``Wardrobe: <text>`` taken from ``wardrobe["default"]``, which may be
      a plain string or a dict with a ``"description"`` key. A ``wardrobe``
      value that is itself a plain string is used directly.
    * Every sentence of ``visual_description`` that mentions a structural
      keyword (hair, scar, freckle, tattoo, mark, skin). Sentences without
      such a keyword (e.g. pose descriptions) are dropped.

    Args:
        bible: Parsed bible data; characters are looked up under
            ``bible["characters"][char_key]``.
        char_key: Key of the character to describe.

    Returns:
        The spec as a single string, or ``""`` when the character is
        unknown or has no usable wardrobe/appearance data. Missing keys and
        JSON ``null`` values are treated as empty rather than raising.
    """
    # `or {}` (instead of a .get default) also covers keys present as null.
    characters = bible.get("characters") or {}
    char = characters.get(char_key) or {}
    parts = []

    # Wardrobe items
    wardrobe = char.get("wardrobe") or {}
    if isinstance(wardrobe, dict):
        phase = wardrobe.get("default") or ""
    elif isinstance(wardrobe, str):
        phase = wardrobe
    else:
        phase = ""
    if isinstance(phase, dict):
        phase = phase.get("description") or ""
    if phase:
        parts.append(f"Wardrobe: {phase}")

    # Structural appearance from visual_description
    visual = char.get("visual_description") or ""
    if isinstance(visual, str) and visual:
        keywords = ("hair", "scar", "freckle", "tattoo", "mark", "skin")
        # Split into rough sentences; keep only those naming a structural
        # property (substring match, case-insensitive).
        for line in visual.replace(". ", ".\n").split("\n"):
            low = line.lower()
            if any(kw in low for kw in keywords):
                parts.append(line.strip())

    return "\n".join(parts)


def _get_shot_type(shot: dict) -> str:
    """Return the shot type recorded in a shot's ``prompt_data``.

    ``prompt_data["prompt_skeleton"]["shot_type"]`` wins when that key is
    present; otherwise ``prompt_data["shot_type"]`` is used, and ``""`` when
    neither exists.
    """
    prompt_data = shot.get("prompt_data", {})
    fallback = prompt_data.get("shot_type", "")
    skeleton = prompt_data.get("prompt_skeleton", {})
    return skeleton.get("shot_type", fallback)


def _find_latest_keyframe(shot_id: str, frames_dir: Path) -> Optional[Path]:
    """Return the most recently modified PNG keyframe for ``shot_id``.

    Two file-stem conventions are recognised:
      - ``*_EP001_S01`` / ``*_EP001_S01_take2``  (batch runner format)
      - ``*_shot_EP001_SH01``                    (older format)

    Returns ``None`` when ``frames_dir`` does not exist or holds no match.
    """
    import re

    if not frames_dir.exists():
        return None

    # EP001_SH01 -> S01, EP001_SH02A -> S02A
    short_id = shot_id.replace("EP001_SH", "S")

    # Both patterns are anchored at the end of the stem with `$`, so the id
    # S02 cannot match a file that belongs to S02A.
    stem_patterns = (
        re.compile(rf"_EP001_{re.escape(short_id)}(?:_take\d+)?$"),
        re.compile(rf"_shot_{re.escape(shot_id)}$"),
    )

    def _is_keyframe(path):
        # PNG only (suffix compared case-insensitively), stem must match.
        if path.suffix.lower() != ".png":
            return False
        return any(rx.search(path.stem) for rx in stem_patterns)

    matches = [entry for entry in frames_dir.iterdir() if _is_keyframe(entry)]
    if not matches:
        return None
    # Newest by modification time.
    return max(matches, key=lambda entry: entry.stat().st_mtime)


def _resolve_continuity_refs(all_shots: list, shot_idx: int, chars: list, frames_dir: Path) -> tuple:
    """Look up the two continuity keyframes for the shot at ``shot_idx``.

    Returns ``(continuity_ref, preceding_ref)``; either may be ``None``:

    * ``continuity_ref`` — newest keyframe of the closest earlier shot that
      shares at least one character with ``chars`` *and* has a keyframe on
      disk (earlier shots without a keyframe are skipped).
    * ``preceding_ref`` — newest keyframe of the immediately preceding shot,
      regardless of which characters it contains.
    """
    continuity_ref = None
    for prev_shot in reversed(all_shots[:shot_idx]):
        prev_chars = prev_shot.get("characters_in_shot", [])
        if any(c in prev_chars for c in chars):
            continuity_ref = _find_latest_keyframe(prev_shot["shot_id"], frames_dir)
            if continuity_ref:
                break

    preceding_ref = None
    if shot_idx > 0:
        preceding_ref = _find_latest_keyframe(all_shots[shot_idx - 1]["shot_id"], frames_dir)
    return continuity_ref, preceding_ref


def main():
    """Regenerate every character shot of the episode and report the results.

    Loads the storyboard and bible for the hard-coded project/episode,
    selects the shots that list at least one character, and for each one
    builds a prompt plus an identity gate and dispatches an ``image_t2i``
    step. With ``--dry-run`` on the command line only the planned shots and
    their continuity references are logged.

    Returns:
        Process exit code: ``0`` when every shot succeeded (or on dry run),
        ``1`` when at least one shot failed.
    """
    dry_run = "--dry-run" in sys.argv
    project = "starsend-test"
    episode = 1

    max_gate_retries = 3
    # Worst case per shot: the first attempt plus every gate retry.
    # NOTE(review): 0.173 is used as a per-attempt cost estimate in USD —
    # confirm against current model pricing.
    max_attempts_per_shot = max_gate_retries + 1
    est_cost_per_attempt = 0.173

    # Load plan + bible
    data = load_storyboard(episode, project)
    bible_path = CoreProjectPaths.for_project(project).global_bible_path
    bible = {}
    if bible_path.exists():
        # Context manager closes the handle; JSON is read as UTF-8 explicitly
        # instead of relying on the platform default encoding.
        with open(bible_path, encoding="utf-8") as fh:
            bible = json.load(fh)
    assets = AssetManager()
    model = get_model("production", "image")

    # All shots in sequence (for continuity lookups)
    all_shots = data["shots"]
    frames_dir = projects_root() / project / "output" / "frames" / f"ep_{episode:03d}"

    # Character shots (non-ENV), paired with their position in all_shots so
    # continuity lookups need no list.index() scan (which would also pick
    # the wrong position if two shot dicts compared equal).
    char_shots = [
        (idx, s) for idx, s in enumerate(all_shots)
        if s.get("characters_in_shot", [])
    ]

    logger.info("=" * 60)
    logger.info("Batch StepRunner + Gate 2")
    logger.info("  %d character shots to generate", len(char_shots))
    logger.info("  Model: %s", model)
    logger.info("  Max cost: ~$%.2f", len(char_shots) * est_cost_per_attempt * max_attempts_per_shot)
    logger.info("=" * 60)

    if dry_run:
        for s_idx, s in char_shots:
            sid = s["shot_id"]
            chars = s.get("characters_in_shot", [])
            shot_type = _get_shot_type(s)
            # Show continuity refs that would be used
            cont_ref, prec_ref = _resolve_continuity_refs(all_shots, s_idx, chars, frames_dir)
            prec_type = _get_shot_type(all_shots[s_idx - 1]) if s_idx > 0 else "?"
            logger.info("  %s: type=%s, chars=%s", sid, shot_type, chars)
            logger.info("    continuity_ref=%s", cont_ref.name if cont_ref else "none")
            logger.info("    preceding=%s (%s)", prec_ref.name if prec_ref else "none", prec_type)
        logger.info("DRY RUN — no generation")
        return 0

    # Initialize
    store = ExecutionStore(project=project)
    paths = ProjectPaths.for_episode(project, episode)
    runner = StepRunner(store=store, paths=paths)

    ctx = DispatchContext(
        caller_id="batch_gate2_test",
        step_runner=runner,
        project=project,
        episode=episode,
    )

    # Results tracking
    results = []
    total_cost = 0.0
    total_attempts = 0
    start_time = time.time()

    for i, (shot_idx, shot) in enumerate(char_shots):
        shot_id = shot["shot_id"]
        chars = shot.get("characters_in_shot", [])
        # Same resolution as the dry run and the gate, so the reported type
        # always matches what the gate is given; "?" when none is recorded.
        current_shot_type = _get_shot_type(shot)
        shot_type = current_shot_type or "?"

        logger.info("")
        logger.info("─" * 60)
        logger.info("[%d/%d] %s (%s)", i + 1, len(char_shots), shot_id, shot_type)
        logger.info("─" * 60)

        # Resolve refs
        ref_paths = []
        for char_key in chars:
            ref_paths.extend(get_character_refs(char_key, project))

        emotion = shot.get("emotion", "")
        expression_ref = None
        if emotion:
            expression_ref = assets.get_expression_ref(emotion)

        # Location ref: only used when the view file exists on disk
        location_view_id = shot.get("location_view_id")
        scene_ref_path = None
        if location_view_id:
            loc_id = shot.get("asset_data", {}).get("location_id", "")
            if loc_id:
                loc_dir = projects_root() / project / "output" / "refs" / "locations"
                candidate = loc_dir / loc_id.lower() / location_view_id
                if candidate.exists():
                    scene_ref_path = candidate

        # Build prompt
        prompt = build_cinematic_prompt(shot=shot, storyboard=data, is_env=False, bible=bible)

        # Build wardrobe spec from bible for all characters in shot
        wardrobe_specs = []
        for char_key in chars:
            spec = build_gate_wardrobe_spec(bible, char_key)
            if spec:
                wardrobe_specs.append(f"[{char_key}]\n{spec}")
        wardrobe_description = "\n\n".join(wardrobe_specs) if wardrobe_specs else None

        # Resolve continuity refs (Layers 3 & 4)
        # Layer 3: character's last keyframe (closest earlier shot sharing a character)
        # Layer 4: preceding shot keyframe (immediately prior, any character)
        continuity_ref_path, preceding_shot_path = _resolve_continuity_refs(
            all_shots, shot_idx, chars, frames_dir
        )
        # The preceding shot's type is only reported when its keyframe exists.
        preceding_shot_type = None
        if preceding_shot_path:
            preceding_shot_type = _get_shot_type(all_shots[shot_idx - 1])

        logger.info("  Continuity ref: %s", continuity_ref_path.name if continuity_ref_path else "none")
        logger.info("  Preceding shot: %s (%s)", preceding_shot_path.name if preceding_shot_path else "none", preceding_shot_type or "?")

        # Build gate
        gate = make_identity_gate(
            ref_paths=ref_paths,
            prompt_skeleton=shot.get("prompt_data", {}).get("prompt_skeleton"),
            wardrobe_description=wardrobe_description,
            continuity_ref_path=continuity_ref_path,
            preceding_shot_path=preceding_shot_path,
            preceding_shot_type=preceding_shot_type,
            current_shot_type=current_shot_type,
        )

        # Force-reset to keyframe_pending
        current = (store.get_shot(shot_id) or {}).get("status", "unknown")
        store.force_reset_status(
            shot_id, "keyframe_pending",
            reason=f"Batch Gate 2 test — was {current}"
        )

        # Generate!
        t0 = time.time()
        receipt = dispatch(
            "image_t2i",
            {
                "shot_id": shot_id,
                "prompt": prompt,
                "model": model,
                "scene_ref_path": scene_ref_path,
                "identity_refs": ref_paths,
                "expression_refs": [expression_ref.path] if expression_ref else None,
                "aspect_ratio": "9:16",
                "gates": [gate],
                "max_gate_retries": max_gate_retries,
            },
            context=ctx,
        )
        result = receipt.run_result
        elapsed = time.time() - t0
        cost_usd = read_cost_from_result(result)
        final_state = result.metadata.get("final_state", "")
        gate_verdict = result.metadata.get("gate_verdict")
        total_cost += cost_usd

        # Count attempts from the shot store
        shot_data = store.get_shot(shot_id) or {}
        attempts = shot_data.get("attempts", 1)
        total_attempts += attempts

        gate_status = ""
        if gate_verdict:
            gate_status = "PASS" if gate_verdict.passed else "FAIL"
            if not gate_verdict.passed:
                # A missing reason must not break the report line.
                gate_status += f" — {(gate_verdict.reason or '')[:80]}"

        status_icon = "PASS" if result.success else "FAIL"
        logger.info(
            "  %s | %s | $%.3f | %d attempt(s) | %.1fs | gate: %s",
            status_icon, final_state, cost_usd,
            attempts, elapsed, gate_status or "n/a"
        )

        results.append({
            "shot_id": shot_id,
            "shot_type": shot_type,
            "success": result.success,
            "state": final_state,
            "cost": cost_usd,
            "attempts": attempts,
            "elapsed": elapsed,
            "output": result.output_path,
            "gate_passed": gate_verdict.passed if gate_verdict else None,
        })

        # Brief pause between shots to avoid rate limiting
        if i < len(char_shots) - 1:
            time.sleep(3)

    # Summary
    total_elapsed = time.time() - start_time
    passed = sum(1 for r in results if r["success"])
    failed = len(results) - passed
    # With no character shots there is nothing to average; avoid dividing by zero.
    shot_count = len(results)
    avg_attempts = total_attempts / shot_count if shot_count else 0.0
    avg_elapsed = total_elapsed / shot_count if shot_count else 0.0

    logger.info("")
    logger.info("=" * 60)
    logger.info("BATCH COMPLETE")
    logger.info("=" * 60)
    logger.info("")
    logger.info("  Shots: %d/%d passed, %d failed", passed, len(results), failed)
    logger.info("  Total cost: $%.3f", total_cost)
    logger.info("  Total attempts: %d (%.1f avg)", total_attempts, avg_attempts)
    logger.info("  Elapsed: %.0fs (%.1fs avg/shot)", total_elapsed, avg_elapsed)
    logger.info("")

    # Per-shot table
    logger.info("  %-14s %-4s %-6s %-8s %-5s %-8s", "Shot", "Type", "Result", "Cost", "Tries", "Gate")
    logger.info("  %s", "-" * 52)
    for r in results:
        # gate_passed is True / False / None (no verdict recorded)
        gate_str = "PASS" if r["gate_passed"] else ("FAIL" if r["gate_passed"] is False else "—")
        result_str = "OK" if r["success"] else "FAIL"
        logger.info(
            "  %-14s %-4s %-6s $%-7.3f %-5d %-8s",
            r["shot_id"], r["shot_type"], result_str, r["cost"], r["attempts"], gate_str
        )

    if failed:
        logger.info("")
        logger.info("FAILED SHOTS:")
        for r in results:
            if not r["success"]:
                logger.info("  %s: %s", r["shot_id"], r["state"])

    return 0 if failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
