#!/usr/bin/env python3
"""
pipeline.py — Main generation orchestrator for Starsend.

Phase A refactoring: Router-pipeline architecture with strategy pattern.

Pipeline strategies:
  StillPipeline   — 3-pass image generation (simple/standard/complex tiers)
  I2VPipeline     — Keyframe → I2V via Kling (stub — Phase B)
  T2VPipeline     — Text-to-video via Kling/SeedDance/Veo (stub — Phase B)
  MultiShotPipeline — Scene batch via SeedDance (stub — Phase B)

The main Pipeline class is a thin orchestrator:
  Load data → plan scenes → route shots → group multi-shot batches →
  dispatch to strategies → track costs → enforce budget
"""

import json
import logging
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from typing import Optional

# ── Optional imports (work in dry_run without these installed) ────────

try:
    from PIL import Image

    _HAS_PIL = True
except ImportError:
    Image = None
    _HAS_PIL = False

# ── Project imports ───────────────────────────────────────────────────

from recoil.pipeline._lib.recoil_bridge import (
    load_storyboard,
    load_breakdown,
    load_project_config,
    get_character_refs,
    resolve_character_for_episode,
    get_all_scenes,
    get_shot_by_id,
)
from recoil.execution.asset_manager import AssetManager, ReferenceImage
from recoil.pipeline._lib.prompt_engine import (
    GridType,
    build_cinematic_prompt,
    build_grid_prompt,
    build_two_character_prompt,
)
from recoil.core.model_profiles import get_model, get_cost
from recoil.execution.assembler import PromptPackage

# CP-2 Phase 2 (spec-review edit #6, locked 2026-04-25): the legacy
# `from lib.api_client import get_client` line was retired here. The new
# google keyframe adapter is imported below so the registry path is reachable
# from this module (keyframe consumers are rewired to it in Phase 6).
# Until then `get_client()` continues to dispatch keyframe/previz/video work,
# but it is now sourced from `execution.api_client` directly rather than the
# `lib.api_client` proxy — the load-bearing constraint is "no
# `from lib.api_client import get_client` on pipeline.py:51 after Phase 2."
from recoil.execution.providers.google import GoogleAdapter as _GoogleKeyframeAdapter  # noqa: F401  (Phase 6 wires consumers to this)
from recoil.execution.api_client import get_client
from recoil.pipeline._lib.exceptions import BudgetExceededError
from orchestrator.cost_tracker import CostTracker as CanonicalCostTracker
from orchestrator.manifest import EpisodeLog
from orchestrator.scene_planner import (
    plan_episode,
    route_shot,
    partition_long_scene,
)

# ── Constants ─────────────────────────────────────────────────────────

from recoil.core.paths import (
    CONFIG_PATH,
    get_config,
    DEFAULT_PROJECT,
    ProjectPaths,
    projects_root,
)
from recoil.execution.execution_store import ExecutionStore
from recoil.execution.step_runner import StepRunner, make_identity_gate
from recoil.pipeline.core.cost import read_cost_from_record_safe
# Episode-scoped variant — aliased to avoid colliding with the project-scoped
# ProjectPaths imported above from recoil.core.paths. The two are intentionally
# distinct: core.paths.ProjectPaths is project-scoped, execution.step_types
# .ProjectPaths.for_episode() returns a per-episode bundle (frames_dir,
# video_dir, ...).
from recoil.execution.step_types import ProjectPaths as EpisodeProjectPaths

from recoil.pipeline.core.dispatch import dispatch
from recoil.pipeline.core.dispatch_context import DispatchContext
from recoil.pipeline.core.cost import read_cost_from_result

logger = logging.getLogger("starsend.pipeline")


def _shot_id(shot: dict):
    """Get shot identifier from either plan or legacy format."""
    return shot.get("shot_id", shot.get("id", 0))


def _shot_name(shot: dict) -> str:
    """Get shot display name from either format."""
    if "shot_id" in shot:
        return shot["shot_id"]
    return shot.get("name", f"shot_{shot.get('id', 0)}")


# ── PipelineContext — shared state for strategies ─────────────────────


@dataclass
class PipelineContext:
    """Shared state passed to pipeline strategies.

    One instance is handed to every ``PipelineStrategy.execute`` call so that
    strategies can read the loaded project data, record costs, and share
    per-scene reference images through ``scene_refs``.
    """

    # ── Required run parameters and loaded data ──
    episode: int
    project: str
    dry_run: bool  # when True, strategies print their plan and skip generation
    storyboard: dict  # strategies read the "bible" entry from this
    breakdown: dict
    project_config: dict
    starsend_config: dict
    asset_manager: AssetManager
    cost_tracker: CanonicalCostTracker  # strategies call .record(...) per render
    log: EpisodeLog
    assembler: "ShotAssembler"  # compiles grid / production prompt packages
    # ── Optional / mutable state ──
    character_data: dict = field(default_factory=dict)
    scene_refs: dict = field(default_factory=dict)  # scene_index -> Path
    execution_store: Optional[ExecutionStore] = None
    # When set, Pro renders go through dispatch(); otherwise the direct
    # generation fallback is used.
    step_runner: Optional["StepRunner"] = None
    # ── Output directories ──
    output_dir: Optional[Path] = None  # presumably final stills — TODO confirm
    video_dir: Optional[Path] = None  # presumably video outputs — TODO confirm
    grids_dir: Optional[Path] = None  # Flash grid images are written here
    panels_dir: Optional[Path] = None  # panels split from a grid are written here
    scene_refs_dir: Optional[Path] = None  # scene_<index>_ref.png files are written here
    # NOTE(review): how budget_cap is enforced is not visible in this block —
    # confirm against the budget-check helper.
    budget_cap: float = 0.0


# ── PipelineStrategy ABC ─────────────────────────────────────────────


class PipelineStrategy(ABC):
    """Interface every sub-pipeline strategy must implement."""

    @abstractmethod
    def execute(self, shot: dict, context: PipelineContext) -> dict:
        """Run this strategy for one shot and report the outcome.

        Args:
            shot: The shot dict to process.
            context: Shared pipeline state for the current run.

        Returns:
            A result dict containing at least ``shot_id``, ``status`` and
            ``tier``.
        """


# ── StillPipeline — handles simple/standard/complex tiers ────────────


class StillPipeline(PipelineStrategy):
    """3-pass image generation pipeline.

    Handles all three still tiers internally:
      simple   — Direct Pro render (1 API call)
      standard — Flash grid → hero extraction → Pro final render
      complex  — Flash grid + expression refs + multiple Pro candidates

    Every tier ends in a Pro render.  When ``context.step_runner`` is set the
    render is dispatched through ``dispatch("image_t2i", ...)``; otherwise the
    module-level ``_generate_image`` fallback is used.

    NOTE(review): all tiers read the shot identifier from the legacy ``"id"``
    key only, while the module-level ``_shot_id`` helper also accepts
    ``"shot_id"`` — confirm which shot format reaches this pipeline.
    """

    # Pro candidates requested for the complex tier: passed as
    # ``max_gate_retries`` on the StepRunner path and used as the explicit
    # render count on the fallback path.
    _COMPLEX_CANDIDATES = 3

    def execute(self, shot: dict, context: PipelineContext) -> dict:
        """Route *shot* to the handler for its ``_tier`` (default: standard)."""
        tier = shot.get("_tier", "standard")
        if tier == "simple":
            return self._run_simple(shot, context)
        if tier == "complex":
            return self._run_complex(shot, context)
        return self._run_standard(shot, context)

    # ── Shared helpers ────────────────────────────────────────────────

    @staticmethod
    def _extract_refs_from_package(package: PromptPackage):
        """Extract identity and expression ref paths from a compiled package.

        References whose file does not exist on disk are skipped.

        Returns:
            ``(identity_ref_paths, expression_ref_paths)`` — two lists of paths.
        """
        identity_ref_paths = []
        expression_ref_paths = []
        for ref in package.references:
            if not ref.path.exists():
                continue
            if ref.ref_type == "identity":
                identity_ref_paths.append(ref.path)
            elif ref.ref_type == "expression":
                expression_ref_paths.append(ref.path)
        return identity_ref_paths, expression_ref_paths

    @staticmethod
    def _build_gates(identity_ref_paths: list[Path]) -> list:
        """Build gate functions for StepRunner from available refs.

        An identity gate is added only when at least one identity ref exists.
        """
        gates = []
        if identity_ref_paths:
            gates.append(make_identity_gate(ref_paths=identity_ref_paths))
        return gates

    @staticmethod
    def _error_result(shot_id, tier: str, message: str) -> dict:
        """Build an error result carrying shot_id, status, tier and error."""
        return {
            "shot_id": shot_id,
            "status": "error",
            "tier": tier,
            "error": message,
        }

    @staticmethod
    def _record_cost(
        shot: dict,
        context: PipelineContext,
        package: PromptPackage,
        cost,
        tier: str,
        pass_label: str,
    ) -> None:
        """Record one render in the CLI cost tracker."""
        shot_id = shot.get("id", 0)
        context.cost_tracker.record(
            shot_id,
            shot.get("name", f"shot_{shot_id}"),
            package.model,
            package.aspect_ratio,
            package.image_size,
            cost,
            tier,
            pass_label,
        )

    @staticmethod
    def _store_scene_ref(
        context: PipelineContext, scene_index, image_bytes: bytes
    ) -> None:
        """Write *image_bytes* as the scene's reference image and register it."""
        ref_path = context.scene_refs_dir / f"scene_{scene_index}_ref.png"
        ref_path.write_bytes(image_bytes)
        context.scene_refs[scene_index] = ref_path
        logger.info("  Saved scene ref: %s", ref_path.name)

    @staticmethod
    def _build_refs_sent(
        package: PromptPackage,
        scene_ref,
        identity_refs: list,
        expression_refs: list,
    ) -> list:
        """Describe the references attached to a Pro render for the snapshot.

        ``sent_to_model`` is True only for refs that are actually handed to
        dispatch: refs dropped by ``_extract_refs_from_package`` (missing on
        disk) are still listed, but flagged as not sent.
        """
        sent_paths = {*identity_refs, *expression_refs}
        entries = [
            {
                "type": "location",
                "id": "scene_ref",
                "url": str(scene_ref) if scene_ref else "",
                "label": "Scene Ref",
                "sent_to_model": bool(scene_ref),
            }
        ]
        # Identity refs are listed before expression refs.
        for ref_type, entry_type in (
            ("identity", "character"),
            ("expression", "expression"),
        ):
            entries.extend(
                {
                    "type": entry_type,
                    "id": str(ref.path.stem),
                    "url": str(ref.path),
                    "label": ref.path.stem,
                    "sent_to_model": ref.path in sent_paths,
                }
                for ref in package.references
                if ref.ref_type == ref_type
            )
        return entries

    def _plan_hero_panel(
        self, shot: dict, context: PipelineContext, tier: str
    ) -> Optional[Path]:
        """Run Pass 1 (Flash grid) and Pass 2 (hero extraction).

        In dry-run mode the passes are only described on stdout.

        Returns:
            Path of the hero panel, or ``None`` in dry-run mode, when the grid
            render produced no data, or when the grid yielded no panels.
        """
        scene_index = shot.get("_scene_index", 0)
        is_complex = tier == "complex"

        # ── Pass 1: Flash grid ──
        grid_package = context.assembler.compile_grid_package(
            shots=[shot],
            storyboard=context.storyboard,
            scene_index=scene_index,
            episode=context.episode,
            character_data=context.character_data,
        )

        if context.dry_run:
            if is_complex:
                print("--- PASS 1: Flash Grid (complex) ---")
            else:
                print("--- PASS 1: Flash Grid ---")
            print(grid_package.describe())

        ep_prefix = _ep_prefix(context)

        if context.dry_run:
            if is_complex:
                print("--- PASS 2: Hero Extraction (simulated, complex) ---")
            else:
                print("--- PASS 2: Hero Extraction (simulated) ---")
                print("  Would extract best panel from grid")
            return None

        _check_budget(context, grid_package.estimated_cost())
        grid_data = _generate_image(grid_package, context)
        if not grid_data:
            return None

        self._record_cost(
            shot,
            context,
            grid_package,
            get_cost(grid_package.model),
            tier,
            "grid",
        )
        file_stem = f"{ep_prefix}_S{scene_index:02d}"
        grid_path = context.grids_dir / f"{file_stem}_grid.png"
        grid_path.write_bytes(grid_data)
        logger.info("  Grid saved: %s", grid_path.name)

        # ── Pass 2: Hero extraction ──
        panels = _split_grid(grid_data, GridType.DIRECTORS_TAKE)
        if not panels:
            return None

        for panel_no, panel_bytes in enumerate(panels, start=1):
            panel_path = context.panels_dir / f"{file_stem}_f{panel_no}.png"
            panel_path.write_bytes(panel_bytes)

        # Middle panel when the grid has 9+ panels, otherwise the first one.
        hero_idx = len(panels) // 2 if len(panels) >= 9 else 0
        hero_path = context.panels_dir / f"{file_stem}_f{hero_idx + 1}.png"
        logger.info("  Hero panel: %s", hero_path.name)
        return hero_path

    def _render_pro_via_step_runner(
        self,
        shot: dict,
        context: PipelineContext,
        package: PromptPackage,
        *,
        tier: str,
        reason: str,
        scene_ref,
        extra_payload: Optional[dict] = None,
        extra_result: Optional[dict] = None,
    ) -> dict:
        """Dispatch the Pro render through StepRunner and build the result.

        Budget enforcement is the caller's responsibility.

        Args:
            shot: Shot dict being rendered.
            context: Shared pipeline state; ``context.step_runner`` must be set.
            package: Compiled production package for the Pro render.
            tier: ``"simple"``, ``"standard"`` or ``"complex"``.
            reason: Routing reason recorded in the inputs snapshot.
            scene_ref: Scene reference path for this shot's scene, or ``None``.
            extra_payload: Additional keys merged into the dispatch payload
                (e.g. ``pose_ref_path``, ``max_gate_retries``).
            extra_result: Additional keys placed in the result dict before
                ``take_id``.

        Returns:
            Result dict with shot_id, status, tier, output, cost and take_id.
        """
        # Function-scope imports are kept as in the original implementation.
        from recoil.pipeline._lib.take_inputs import build_inputs_snapshot
        from recoil.pipeline._lib.prompt_engine import (
            build_prompt_sections_from_plan,
        )

        shot_id = shot.get("id", 0)
        scene_index = shot.get("_scene_index", 0)

        identity_refs, expression_refs = self._extract_refs_from_package(package)
        gates = self._build_gates(identity_refs)

        prompt_sections = build_prompt_sections_from_plan(
            shot,
            context.storyboard.get("bible", {}),
            context.project_config,
            episode=context.episode,
        )
        inputs_snap = build_inputs_snapshot(
            prompt_sections=prompt_sections,
            routing={
                "pipeline": "still",
                "model": package.model,
                "tier": tier,
                "reason": reason,
            },
            refs_sent=self._build_refs_sent(
                package, scene_ref, identity_refs, expression_refs
            ),
            bible=context.storyboard.get("bible", {}),
            project_config=context.project_config,
            generation_params={
                "aspect_ratio": package.aspect_ratio,
                "model": package.model,
            },
            builder_name=f"still_{tier}_pro",
        )

        payload = {
            "shot_id": str(shot_id),
            "prompt": package.prompt_text,
            "model": package.model,
            "scene_ref_path": scene_ref,
            "identity_refs": identity_refs,
            "expression_refs": expression_refs,
            "aspect_ratio": package.aspect_ratio,
            "gates": gates,
            "inputs_snapshot": inputs_snap,
        }
        payload.update(extra_payload or {})

        ctx = DispatchContext(
            caller_id="pipeline_orchestrator",
            step_runner=context.step_runner,
            project=context.project,
            episode=context.episode,
        )
        result = dispatch("image_t2i", payload, context=ctx).run_result
        cost = read_cost_from_result(result)

        # Record in CLI cost tracker (separate from ExecutionStore cost).
        self._record_cost(shot, context, package, cost, tier, "pro")

        # Scene ref propagation (if this shot provides one).
        if shot.get("_provides_scene_ref") and result.output_path:
            output_abs = Path(result.output_path)
            if not output_abs.is_absolute():
                # A relative output_path is resolved against the project root.
                from recoil.execution.step_types import ProjectPaths

                paths = ProjectPaths.for_episode(context.project, context.episode)
                output_abs = paths.project_root / result.output_path
            if output_abs.exists():
                self._store_scene_ref(context, scene_index, output_abs.read_bytes())

        outcome = {
            "shot_id": shot_id,
            "status": "ok"
            if result.success
            else (result.metadata.get("final_state") or ""),
            "tier": tier,
            "output": result.output_path or "",
            "cost": cost,
        }
        outcome.update(extra_result or {})
        outcome["take_id"] = result.metadata.get("take_id")
        return outcome

    def _finish_fallback(
        self,
        shot: dict,
        context: PipelineContext,
        image_data: bytes,
        tier: str,
        cost,
        extra_result: Optional[dict] = None,
    ) -> dict:
        """Save a fallback-path render, propagate the scene ref, build result."""
        output_path = _save_output(image_data, shot, context)

        if shot.get("_provides_scene_ref"):
            self._store_scene_ref(context, shot.get("_scene_index", 0), image_data)

        outcome = {
            "shot_id": shot.get("id", 0),
            "status": "ok",
            "tier": tier,
            "output": str(output_path),
            "cost": cost,
        }
        outcome.update(extra_result or {})
        return outcome

    # ── Tier handlers ─────────────────────────────────────────────────

    def _run_simple(self, shot: dict, context: PipelineContext) -> dict:
        """Direct Pro render. No grid planning."""
        shot_id = shot.get("id", 0)
        scene_ref = context.scene_refs.get(shot.get("_scene_index", 0))

        package = context.assembler.compile_production_package(
            shot=shot,
            storyboard=context.storyboard,
            episode=context.episode,
            scene_ref_path=scene_ref,
            character_data=context.character_data,
        )

        if context.dry_run:
            print(package.describe())
            return {"shot_id": shot_id, "status": "dry_run", "tier": "simple"}

        # Budget check
        _check_budget(context, package.estimated_cost())

        # ── StepRunner path (unified state tracking) ──
        if context.step_runner:
            return self._render_pro_via_step_runner(
                shot,
                context,
                package,
                tier="simple",
                reason="direct Pro render",
                scene_ref=scene_ref,
            )

        # ── Fallback path (no StepRunner) ──
        image_data = _generate_image(package, context)
        if image_data is None:
            return self._error_result(
                shot_id, "simple", "Generation returned no image"
            )

        cost = get_cost(package.model)
        self._record_cost(shot, context, package, cost, "simple", "pro")
        return self._finish_fallback(shot, context, image_data, "simple", cost)

    def _run_standard(self, shot: dict, context: PipelineContext) -> dict:
        """Full 3-pass: Flash grid → hero extraction → Pro final render."""
        shot_id = shot.get("id", 0)
        scene_ref = context.scene_refs.get(shot.get("_scene_index", 0))

        # ── Passes 1 + 2: Flash grid and hero extraction ──
        hero_path = self._plan_hero_panel(shot, context, "standard")

        # ── Pass 3: Pro final render ──
        pro_package = context.assembler.compile_production_package(
            shot=shot,
            storyboard=context.storyboard,
            episode=context.episode,
            scene_ref_path=scene_ref,
            pose_ref_path=hero_path,
            character_data=context.character_data,
        )

        if context.dry_run:
            print("--- PASS 3: Pro Render ---")
            print(pro_package.describe())
            return {"shot_id": shot_id, "status": "dry_run", "tier": "standard"}

        _check_budget(context, pro_package.estimated_cost())

        # ── StepRunner path for Pass 3 (unified state tracking) ──
        if context.step_runner:
            return self._render_pro_via_step_runner(
                shot,
                context,
                pro_package,
                tier="standard",
                reason="3-pass standard render",
                scene_ref=scene_ref,
                extra_payload={"pose_ref_path": hero_path},
            )

        # ── Fallback path (no StepRunner) ──
        image_data = _generate_image(pro_package, context)
        if image_data is None:
            return self._error_result(
                shot_id, "standard", "Pro render returned no image"
            )

        pro_cost = get_cost(pro_package.model)
        self._record_cost(shot, context, pro_package, pro_cost, "standard", "pro")
        return self._finish_fallback(shot, context, image_data, "standard", pro_cost)

    def _run_complex(self, shot: dict, context: PipelineContext) -> dict:
        """3-pass + expression refs + multiple Pro candidates."""
        shot_id = shot.get("id", 0)
        scene_ref = context.scene_refs.get(shot.get("_scene_index", 0))
        num_candidates = self._COMPLEX_CANDIDATES

        # ── Passes 1 + 2: Flash grid and hero extraction ──
        hero_path = self._plan_hero_panel(shot, context, "complex")

        # ── Pass 3: Pro render (multiple candidates via gate retry) ──
        pro_package = context.assembler.compile_production_package(
            shot=shot,
            storyboard=context.storyboard,
            episode=context.episode,
            scene_ref_path=scene_ref,
            pose_ref_path=hero_path,
            character_data=context.character_data,
            num_candidates=num_candidates,
        )

        if context.dry_run:
            print(f"--- PASS 3: Pro Render (complex, {num_candidates} candidates) ---")
            print(pro_package.describe())
            return {"shot_id": shot_id, "status": "dry_run", "tier": "complex"}

        # ── StepRunner path for Pass 3 (unified state tracking) ──
        # Gate retries stand in for the manual multi-candidate loop:
        # max_gate_retries is set to num_candidates.
        if context.step_runner:
            # Enforce the budget before dispatch, as the simple and standard
            # tiers do; the fallback loop below checks per candidate instead.
            _check_budget(context, pro_package.estimated_cost())
            return self._render_pro_via_step_runner(
                shot,
                context,
                pro_package,
                tier="complex",
                reason="3-pass complex render",
                scene_ref=scene_ref,
                extra_payload={
                    "pose_ref_path": hero_path,
                    "max_gate_retries": num_candidates,
                },
                extra_result={"candidates_generated": num_candidates},
            )

        # ── Fallback path (no StepRunner) ──
        best_image = None
        total_pro_cost = 0.0

        for candidate_no in range(1, num_candidates + 1):
            _check_budget(context, pro_package.estimated_cost())
            image_data = _generate_image(pro_package, context)
            pro_cost = get_cost(pro_package.model)
            # Cost is recorded for every attempt, including ones that
            # returned no image.
            self._record_cost(
                shot,
                context,
                pro_package,
                pro_cost,
                "complex",
                f"pro_candidate_{candidate_no}",
            )
            total_pro_cost += pro_cost

            if image_data:
                candidate_path = _save_output(
                    image_data, shot, context, suffix=f"_c{candidate_no}"
                )
                logger.info("  Candidate %s: %s", candidate_no, candidate_path.name)

                # The first successful candidate is kept as the final image.
                if best_image is None:
                    best_image = image_data

        if best_image is None:
            return self._error_result(shot_id, "complex", "All Pro candidates failed")

        return self._finish_fallback(
            shot,
            context,
            best_image,
            "complex",
            total_pro_cost,
            extra_result={"candidates_generated": num_candidates},
        )


def _run_gate_3_on_video(video_path, context: PipelineContext, ref_paths=None):
    """Run the Gate 3 video check in fail-open-but-flag mode.

    The returned dict always has ``passed`` set to True. When the validator
    reports ``deferred`` in its details, or when the gate cannot run at all
    (validation module not importable, or any other exception), the result is
    marked ``deferred`` and ``flagged_for_review`` instead of failing the shot.

    Args:
        video_path: Path of the video handed to ``Validator.run_gate_3``.
        context: Pipeline context (not read by this function).
        ref_paths: Reference paths forwarded to the gate; None becomes [].

    Returns:
        dict with ``passed``, ``deferred``, ``cost`` and ``flagged_for_review``;
        ``details`` is present only when the gate ran, ``api_error`` only when
        it raised a non-import exception.
    """
    try:
        from recoil.pipeline._lib.validation import Validator

        outcome = Validator().run_gate_3(video_path, ref_paths or [])
        is_deferred = outcome.details.get("deferred", False)
        gate_details = outcome.details
        gate_cost = outcome.cost
    except ImportError:
        # Validation module unavailable: nothing was spent, flag for a human.
        return {
            "passed": True,
            "deferred": True,
            "cost": 0.0,
            "flagged_for_review": True,
        }
    except Exception as e:
        logger.error("Gate 3 DEFERRED (fail-open-but-flag): %s", e)
        return {
            "passed": True,
            "deferred": True,
            "cost": 0.0,
            "api_error": str(e),
            "flagged_for_review": True,
        }
    else:
        # The gate ran; a deferred verdict is surfaced as a review flag.
        return {
            "passed": True,
            "deferred": is_deferred,
            "details": gate_details,
            "cost": gate_cost,
            "flagged_for_review": is_deferred,
        }


def _generate_with_gate_retry(
    generate_fn,
    context: PipelineContext,
    shot: dict,
    shot_id: str,
    max_attempts: int = 3,
    ref_paths: list | None = None,
    prompt_skeleton: dict | None = None,
    skip_gate_2: bool = False,
):
    """Generate an image with sequential short-circuit gate retry.

    Gate 1 (mechanical) → auto-retry on fail (up to max_attempts).
    Gate 2 (semantic) → escalate immediately on fail (no retry).
    Gate errors other than ImportError → fail-closed, halt the shot.
    Validation module not importable → both gates pass through, so the
    function still works when the validator is not installed.

    Args:
        generate_fn: Callable(shot, context) that returns dict with
                     "output" (path str), "cost" (float), "status" (str).
        context: PipelineContext (only forwarded to generate_fn).
        shot: Shot dict.
        shot_id: Shot ID string (used in log messages).
        max_attempts: Max generation attempts for mechanical failures.
        ref_paths: Character ref paths for Gate 2; Gate 2 is skipped when
                   this is empty or None.
        prompt_skeleton: Prompt skeleton for Gate 2.
        skip_gate_2: If True, skip semantic check (for QUICK RETRY mode).

    Returns:
        dict with keys: status, cost, retry_waste_cost, gate_results, attempts;
        plus "output" (on "ok" and "semantic_failed"), "error" (on "failed")
        and "failed_gate" (on gate failures / gate errors). status is one of
        "ok", "failed", "mechanical_failed", "semantic_failed".
    """
    total_cost = 0.0
    waste_cost = 0.0
    last_gate_results = {}

    for attempt in range(1, max_attempts + 1):
        is_retry = attempt > 1

        # Step 1: Generate
        gen_result = generate_fn(shot, context)
        gen_cost = read_cost_from_record_safe(gen_result)

        if gen_result.get("status") in ("error", "failed"):
            total_cost += gen_cost
            if is_retry:
                waste_cost += gen_cost
            return {
                "status": "failed",
                "error": gen_result.get("error", "Generation failed"),
                "cost": total_cost,
                "retry_waste_cost": waste_cost,
                "gate_results": last_gate_results,
                "attempts": attempt,
            }

        # An empty/missing "output" must be checked explicitly: Path("") is
        # the current directory, which exists and would wrongly pass the test.
        output = gen_result.get("output")
        if not output or not Path(output).exists():
            total_cost += gen_cost
            if is_retry:
                waste_cost += gen_cost
            continue  # Try again
        image_path = Path(output)

        # Step 2: Run Gate 1 (mechanical) — always runs
        # Stays None only when the validation module could not be imported.
        validator = None
        try:
            from recoil.pipeline._lib.validation import Validator

            validator = Validator()
            g1 = validator.run_gate_1_image(image_path)
            gate_1_result = {
                "passed": g1.passed,
                "details": g1.details,
                "cost": g1.cost,
            }
        except ImportError:
            gate_1_result = {"passed": True, "details": {}, "cost": 0.0}
        except Exception as e:
            # Fail-closed on API error
            logger.error("Gate 1 API error (fail-closed) for %s: %s", shot_id, e)
            total_cost += gen_cost
            return {
                "status": "failed",
                "error": f"Gate API unreachable: {e}",
                "cost": total_cost,
                "retry_waste_cost": waste_cost,
                "gate_results": {"gate_1": {"passed": False, "api_error": str(e)}},
                "attempts": attempt,
                "failed_gate": "api_error",
            }

        attempt_cost = gen_cost + read_cost_from_record_safe(gate_1_result)

        if not gate_1_result["passed"]:
            # Mechanical failure — auto-retry
            logger.warning(
                "Gate 1 MECHANICAL FAIL for %s (attempt %d/%d): %s",
                shot_id,
                attempt,
                max_attempts,
                gate_1_result.get("details", {}),
            )
            total_cost += attempt_cost
            waste_cost += attempt_cost
            last_gate_results["gate_1"] = gate_1_result
            last_gate_results["last_attempt"] = attempt

            # Clean up failed image to prevent disk bloat
            try:
                image_path.unlink(missing_ok=True)
            except OSError:
                pass

            if attempt < max_attempts:
                continue  # Short-circuit: DO NOT run Gate 2, loop back to generate
            else:
                # Exhausted attempts
                return {
                    "status": "mechanical_failed",
                    "cost": total_cost,
                    "retry_waste_cost": waste_cost,
                    "gate_results": last_gate_results,
                    "attempts": attempt,
                    "failed_gate": "gate_1",
                }

        # Gate 1 passed — now run Gate 2 (semantic) if not skipped
        gate_2_result = None
        if not skip_gate_2 and ref_paths:
            if validator is None:
                # Validation module missing: same pass-through as Gate 1.
                # Without this guard the unbound validator would surface as a
                # bogus "Gate 2 API unreachable" failure.
                gate_2_result = {"passed": True, "details": {}, "cost": 0.0}
            else:
                try:
                    g2 = validator.run_gate_2(image_path, ref_paths, prompt_skeleton)
                    gate_2_result = {
                        "passed": g2.passed,
                        "details": g2.details,
                        "cost": g2.cost,
                    }
                except ImportError:
                    gate_2_result = {"passed": True, "details": {}, "cost": 0.0}
                except Exception as e:
                    logger.error(
                        "Gate 2 API error (fail-closed) for %s: %s", shot_id, e
                    )
                    total_cost += attempt_cost
                    return {
                        "status": "failed",
                        "error": f"Gate 2 API unreachable: {e}",
                        "cost": total_cost,
                        "retry_waste_cost": waste_cost,
                        "gate_results": {
                            "gate_1": gate_1_result,
                            "gate_2": {"api_error": str(e)},
                        },
                        "attempts": attempt,
                        "failed_gate": "api_error",
                    }

            attempt_cost += read_cost_from_record_safe(gate_2_result)

            if not gate_2_result["passed"]:
                # Semantic failure — DO NOT retry, escalate immediately
                logger.warning(
                    "Gate 2 SEMANTIC FAIL for %s: %s — escalating to workbench",
                    shot_id,
                    gate_2_result.get("details", {}),
                )
                total_cost += attempt_cost
                return {
                    "status": "semantic_failed",
                    "output": str(image_path),
                    "cost": total_cost,
                    "retry_waste_cost": waste_cost,
                    "gate_results": {"gate_1": gate_1_result, "gate_2": gate_2_result},
                    "attempts": attempt,
                    "failed_gate": "gate_2",
                }

        # Both gates passed — success
        total_cost += attempt_cost
        all_gates = {"gate_1": gate_1_result}
        if gate_2_result is not None:
            all_gates["gate_2"] = gate_2_result

        return {
            "status": "ok",
            "output": str(image_path),
            "cost": total_cost,
            "retry_waste_cost": waste_cost,
            "gate_results": all_gates,
            "attempts": attempt,
        }

    # Reached when the final attempt produced no output file on disk (the
    # missing-output branch `continue`s past the last iteration) or when
    # max_attempts < 1 so the loop never ran.
    return {
        "status": "mechanical_failed",
        "cost": total_cost,
        "retry_waste_cost": waste_cost,
        "gate_results": last_gate_results,
        "attempts": max_attempts,
        "failed_gate": "gate_1",
    }


# ── I2VPipeline ─────────────────────────────────────────────────────


class I2VPipeline(PipelineStrategy):
    """Keyframe → I2V via Kling 3.0.

    Flow: Generate keyframe (NBP) → Gate 1 + Gate 2 → Submit I2V →
          Poll → Gate 1 (video) + Gate 3 → Update execution plan.

    Video is only produced on the StepRunner path, which dispatches a
    "video_i2v" job. The legacy path (no StepRunner) still generates and
    gates the keyframe but then reports the shot as "skipped": the legacy
    video writer was retired (Phase 11 / CP-5).
    """

    @staticmethod
    def _keyframe_failure_status(kf_status: str) -> str:
        """Map a failed keyframe status to the status reported for the shot.

        "error"/"failed" and anything already prefixed with "keyframe_" pass
        through unchanged; every other status gets a "keyframe_" prefix.
        """
        if kf_status.startswith("keyframe_") or kf_status in ("error", "failed"):
            return kf_status
        return f"keyframe_{kf_status}"

    def execute(self, shot: dict, context: PipelineContext) -> dict:
        """Generate a keyframe for *shot* and, with a StepRunner, animate it.

        Args:
            shot: Shot dict from the plan.
            context: Pipeline context (dry_run flag, step_runner, config, ...).

        Returns:
            Result dict with at least "shot_id", "status" and "pipeline";
            "cost" is included on every path except dry-run.
        """
        shot_id = _shot_id(shot)
        model = shot.get("_target_model", get_model("i2v", "video"))

        if context.dry_run:
            return {
                "shot_id": shot_id,
                "status": "dry_run",
                "pipeline": "i2v",
                "model": model,
            }

        still = StillPipeline()

        # ── StepRunner path (unified state tracking) ──
        # Call still.execute() directly — StepRunner.execute_keyframe() has its
        # own internal gate retry loop, so we must NOT wrap it in
        # _generate_with_gate_retry() (that would create nested 3×3 retries).
        if context.step_runner:
            kf_result = still.execute(shot, context)

            # A result without a status is treated as a plain failure rather
            # than crashing on None.startswith().
            kf_status = kf_result.get("status") or "failed"
            if kf_status != "ok":
                # Propagate keyframe failure with i2v pipeline tag
                return {
                    "shot_id": shot_id,
                    "status": self._keyframe_failure_status(kf_status),
                    "pipeline": "i2v",
                    "error": kf_result.get("error", "Keyframe generation failed"),
                    "cost": read_cost_from_record_safe(kf_result),
                }

            # Resolve keyframe path to absolute before passing to execute_video
            keyframe_path = Path(kf_result.get("output", ""))
            if not keyframe_path.is_absolute():
                keyframe_path = context.step_runner._paths.project_root / keyframe_path
            keyframe_cost = read_cost_from_record_safe(kf_result)
            kf_take_id = kf_result.get("take_id")

            routing = shot.get("routing_data", {})
            action_prompt = (
                shot.get("prompt_data", {})
                .get("prompt_skeleton", {})
                .get("action_line", "")
            )
            duration = routing.get("target_editorial_duration_s", 5)
            aspect_ratio = context.starsend_config.get(
                "production_aspect_ratio", "9:16"
            )

            # Best-effort state transitions: the shot may already be in the
            # target (or a compatible) state, so a rejection is not fatal —
            # but it is logged instead of vanishing silently.
            for target_state, why in (
                ("keyframe_approved", "Auto-approve for I2V pipeline"),
                ("video_pending", "I2V video generation"),
            ):
                try:
                    context.step_runner.transition(
                        str(shot_id),
                        target_state,
                        reason=why,
                    )
                except Exception as e:
                    logger.debug(
                        "Transition to %s skipped for shot %s: %s",
                        target_state,
                        shot_id,
                        e,
                    )

            # Build inputs snapshot for I2V video call
            from recoil.pipeline._lib.take_inputs import build_inputs_snapshot
            from recoil.pipeline._lib.prompt_engine import (
                build_prompt_sections_from_plan,
            )

            prompt_sections = build_prompt_sections_from_plan(
                shot,
                context.storyboard.get("bible", {}),
                context.project_config,
                episode=context.episode,
            )
            inputs_snap = build_inputs_snapshot(
                prompt_sections=prompt_sections,
                routing={
                    "pipeline": "i2v",
                    "model": model,
                    "tier": "i2v",
                    "reason": "keyframe-to-video",
                },
                refs_sent=[
                    {
                        "type": "keyframe",
                        "id": "start_frame",
                        "url": str(keyframe_path),
                        "label": "Start Frame",
                        "sent_to_model": True,
                    }
                ],
                bible=context.storyboard.get("bible", {}),
                project_config=context.project_config,
                generation_params={
                    "duration": duration,
                    "model": model,
                    "aspect_ratio": aspect_ratio,
                },
                builder_name="video_i2v",
                parent_take_id=kf_take_id,
            )

            # End frame from shot plan (sandwich workflow)
            end_frame_path = None
            end_frame_ref = shot.get("end_frame") or shot.get("end_image_url")
            if end_frame_ref:
                candidate = Path(end_frame_ref)
                if not candidate.is_absolute():
                    candidate = context.output_dir / candidate
                if candidate.exists():
                    end_frame_path = candidate
                    logger.info(
                        "Sandwich workflow: end_frame=%s for shot %s",
                        candidate,
                        shot_id,
                    )
                else:
                    logger.warning(
                        "end_frame specified but not found: %s (shot %s)",
                        candidate,
                        shot_id,
                    )

            ctx = DispatchContext(
                caller_id="pipeline_orchestrator",
                step_runner=context.step_runner,
                project=context.project,
                episode=context.episode,
            )
            receipt = dispatch(
                "video_i2v",
                {
                    "shot_id": str(shot_id),
                    "prompt": action_prompt,
                    "model": model,
                    "start_frame": keyframe_path,
                    "end_frame": end_frame_path,
                    "duration": duration,
                    "aspect_ratio": aspect_ratio,
                    "inputs_snapshot": inputs_snap,
                    "parent_take_id": kf_take_id,
                },
                context=ctx,
            )
            video_result = receipt.run_result

            # Record in CLI cost tracker
            shot_name = shot.get("name", f"shot_{shot_id}")
            video_cost = read_cost_from_result(video_result)
            total_cost = keyframe_cost + video_cost
            context.cost_tracker.record(
                shot_id,
                shot_name,
                model,
                aspect_ratio,
                "video",
                video_cost,
                "i2v",
                "video",
            )

            # On failure the status is the runner's final_state ("" if absent).
            return {
                "shot_id": shot_id,
                "status": "ok"
                if video_result.success
                else (video_result.metadata.get("final_state") or ""),
                "pipeline": "i2v",
                "output": video_result.output_path or "",
                "cost": total_cost,
            }

        # ── Fallback path (no StepRunner) ──
        # Use _generate_with_gate_retry here because fallback StillPipeline
        # does NOT have internal gate retries.
        retry_result = _generate_with_gate_retry(
            generate_fn=still.execute,
            context=context,
            shot=shot,
            shot_id=shot_id,
            max_attempts=3,
            ref_paths=shot.get("_ref_paths"),
            prompt_skeleton=shot.get("prompt_data", {}).get("prompt_skeleton"),
        )

        retry_status = retry_result.get("status")
        if retry_status == "mechanical_failed":
            return {
                "shot_id": shot_id,
                "status": "keyframe_mechanical_failed",
                "pipeline": "i2v",
                "gate_results": retry_result.get("gate_results", {}),
                "cost": read_cost_from_record_safe(retry_result),
                "retry_waste_cost": retry_result.get("retry_waste_cost", 0),
                "attempts": retry_result.get("attempts", 0),
            }
        elif retry_status == "semantic_failed":
            return {
                "shot_id": shot_id,
                "status": "keyframe_semantic_failed",
                "pipeline": "i2v",
                "output": retry_result.get("output"),
                "gate_results": retry_result.get("gate_results", {}),
                "cost": read_cost_from_record_safe(retry_result),
                "retry_waste_cost": retry_result.get("retry_waste_cost", 0),
                "attempts": retry_result.get("attempts", 0),
            }
        elif retry_status != "ok":
            return {
                "shot_id": shot_id,
                "status": retry_status or "failed",
                "pipeline": "i2v",
                "error": retry_result.get("error", "Keyframe generation failed"),
                "cost": read_cost_from_record_safe(retry_result),
                "retry_waste_cost": retry_result.get("retry_waste_cost", 0),
                "attempts": retry_result.get("attempts", 0),
            }

        # The keyframe has already been paid for; report that spend even
        # though no video will be produced on this path.
        keyframe_cost = read_cost_from_record_safe(retry_result)

        # Still resolve the client so an unimplemented model is reported the
        # same way as before.
        try:
            get_client(model)
        except NotImplementedError:
            return {
                "shot_id": shot_id,
                "status": "skipped",
                "pipeline": "i2v",
                "reason": f"{model} not yet available",
                "cost": keyframe_cost,
            }

        # Phase 11 tombstone (2026-05-21): legacy I2VPipeline writer disabled.
        # All production calls now go through dispatch_cli / run_overnight
        # (CP-5 entry-point unification). No video job is submitted here: its
        # result would be discarded, so submitting would only spend money and
        # block on polling.
        logger.warning(
            "Legacy I2VPipeline fallback (no StepRunner) is deprecated per CP-5; "
            "skipping video generation for shot %s",
            shot_id,
        )
        return {
            "shot_id": shot_id,
            "status": "skipped",
            "pipeline": "i2v",
            "reason": (
                "Legacy I2VPipeline deprecated per CP-5. "
                "Use dispatch_cli or run_overnight."
            ),
            "cost": keyframe_cost,
        }


# ── T2VPipeline ──────────────────────────────────────────────────────


class T2VPipeline(PipelineStrategy):
    """Text-to-video via Kling/SeedDance/Veo.

    Flow: Submit T2V → Poll → Gate 1 (video) + Gate 3 → Update execution plan.

    Video is only produced on the StepRunner path, which dispatches a
    "video_i2v" job without a start frame. The legacy path (no StepRunner)
    reports the shot as "skipped": the legacy video writer was retired
    (Phase 11 / CP-5).
    """

    def execute(self, shot: dict, context: PipelineContext) -> dict:
        """Generate a video for *shot* directly from its text prompt.

        Args:
            shot: Shot dict from the plan.
            context: Pipeline context (dry_run flag, step_runner, config, ...).

        Returns:
            Result dict with at least "shot_id", "status" and "pipeline".
        """
        shot_id = _shot_id(shot)
        model = shot.get("_target_model", get_model("t2v_default", "video"))

        if context.dry_run:
            return {
                "shot_id": shot_id,
                "status": "dry_run",
                "pipeline": "t2v",
                "model": model,
            }

        # Prefer pre-compiled veo_t2v prompt from plan, fall back to builder
        compiled = shot.get("compiled_prompts", {})
        prompt = compiled.get("veo_t2v", "")
        if not prompt:
            from recoil.pipeline._lib.prompt_engine import build_kling_t2v_prompt

            prompt = build_kling_t2v_prompt(shot)

        routing = shot.get("routing_data", {})
        duration = routing.get("target_editorial_duration_s", 5)
        aspect_ratio = context.starsend_config.get("production_aspect_ratio", "9:16")

        # ── StepRunner path (unified state tracking) ──
        if context.step_runner:
            # Best-effort: the shot may already be in video_pending or a
            # compatible state, so a rejected transition is logged, not raised.
            try:
                context.step_runner.transition(
                    str(shot_id),
                    "video_pending",
                    reason="T2V video generation",
                )
            except Exception as e:
                logger.debug(
                    "Transition to video_pending skipped for shot %s: %s",
                    shot_id,
                    e,
                )

            # Build inputs snapshot for T2V video call
            from recoil.pipeline._lib.take_inputs import build_inputs_snapshot
            from recoil.pipeline._lib.prompt_engine import (
                build_prompt_sections_from_plan,
            )

            prompt_sections = build_prompt_sections_from_plan(
                shot,
                context.storyboard.get("bible", {}),
                context.project_config,
                episode=context.episode,
            )
            inputs_snap = build_inputs_snapshot(
                prompt_sections=prompt_sections,
                routing={
                    "pipeline": "t2v",
                    "model": model,
                    "tier": "t2v",
                    "reason": "text-to-video",
                },
                refs_sent=[],
                bible=context.storyboard.get("bible", {}),
                project_config=context.project_config,
                generation_params={
                    "duration": duration,
                    "model": model,
                    "aspect_ratio": aspect_ratio,
                },
                builder_name="video_t2v",
            )

            # End frame from shot plan (sandwich workflow)
            end_frame_path_t2v = None
            end_frame_ref_t2v = shot.get("end_frame") or shot.get("end_image_url")
            if end_frame_ref_t2v:
                candidate = Path(end_frame_ref_t2v)
                if not candidate.is_absolute():
                    candidate = context.output_dir / candidate
                if candidate.exists():
                    end_frame_path_t2v = candidate
                    logger.info(
                        "Sandwich workflow: end_frame=%s for shot %s",
                        candidate,
                        shot_id,
                    )
                else:
                    logger.warning(
                        "end_frame specified but not found: %s (shot %s)",
                        candidate,
                        shot_id,
                    )

            ctx = DispatchContext(
                caller_id="pipeline_orchestrator",
                step_runner=context.step_runner,
                project=context.project,
                episode=context.episode,
            )
            receipt = dispatch(
                "video_i2v",
                {
                    "shot_id": str(shot_id),
                    "prompt": prompt,
                    "model": model,
                    # No start_frame → T2V mode
                    "end_frame": end_frame_path_t2v,
                    "duration": duration,
                    "aspect_ratio": aspect_ratio,
                    "inputs_snapshot": inputs_snap,
                },
                context=ctx,
            )
            video_result = receipt.run_result

            # Record in CLI cost tracker
            shot_name = shot.get("name", f"shot_{shot_id}")
            video_cost = read_cost_from_result(video_result)
            context.cost_tracker.record(
                shot_id,
                shot_name,
                model,
                aspect_ratio,
                "video",
                video_cost,
                "t2v",
                "video",
            )

            # On failure the status is the runner's final_state ("" if absent).
            return {
                "shot_id": shot_id,
                "status": "ok"
                if video_result.success
                else (video_result.metadata.get("final_state") or ""),
                "pipeline": "t2v",
                "output": video_result.output_path or "",
                "cost": video_cost,
            }

        # ── Fallback path (no StepRunner) ──
        # Still resolve the client so an unimplemented model is reported the
        # same way as before.
        try:
            get_client(model)
        except NotImplementedError:
            return {
                "shot_id": shot_id,
                "status": "skipped",
                "pipeline": "t2v",
                "reason": f"{model} not yet available",
            }

        # Phase 11 tombstone (2026-05-21): legacy T2VPipeline writer disabled.
        # All production calls now go through dispatch_cli / run_overnight
        # (CP-5 entry-point unification). No video job is submitted here: its
        # result would be discarded, so submitting would only spend money and
        # block on polling.
        logger.warning(
            "Legacy T2VPipeline fallback (no StepRunner) is deprecated per CP-5; "
            "skipping video generation for shot %s",
            shot_id,
        )
        return {
            "shot_id": shot_id,
            "status": "skipped",
            "pipeline": "t2v",
            "reason": (
                "Legacy T2VPipeline deprecated per CP-5. "
                "Use dispatch_cli or run_overnight."
            ),
        }


# ── MultiShotPipeline ────────────────────────────────────────────────


class MultiShotPipeline(PipelineStrategy):
    """Scene batch via SeedDance 2.0.

    Stub — SeedDance client not yet implemented. Every shot is delegated to
    T2VPipeline and its result is re-tagged with the fallback pipeline name.
    """

    @staticmethod
    def _retag(outcome: dict) -> dict:
        """Overwrite the result's pipeline tag in place and hand it back."""
        outcome["pipeline"] = "multi_shot_fallback"
        return outcome

    def execute_batch(
        self, shot_group: list[dict], context: PipelineContext
    ) -> list[dict]:
        """Run each shot of a multi-shot group through T2V, one at a time."""
        # One delegate instance serves the whole group.
        delegate = T2VPipeline()
        return [self._retag(delegate.execute(member, context)) for member in shot_group]

    def execute(self, shot: dict, context: PipelineContext) -> dict:
        """Run a single shot through the T2V fallback."""
        return self._retag(T2VPipeline().execute(shot, context))


# ── Strategy registry ─────────────────────────────────────────────────

# Maps a pipeline name to the strategy object that handles it. Each strategy
# is instantiated once, when this module is imported, so every lookup of a
# given name yields the same instance.
_STRATEGIES: dict[str, PipelineStrategy] = {
    "still": StillPipeline(),
    "i2v": I2VPipeline(),
    "t2v": T2VPipeline(),
    "multi_shot": MultiShotPipeline(),
}


def get_strategy(pipeline_name: str) -> PipelineStrategy:
    """Return the registered strategy instance for *pipeline_name*.

    Raises:
        ValueError: If no strategy is registered under that name.
    """
    if pipeline_name in _STRATEGIES:
        return _STRATEGIES[pipeline_name]
    raise ValueError(f"Unknown pipeline: {pipeline_name}")


# ── ShotAssembler ────────────────────────────────────────────────────


class ShotAssembler:
    """Compiles a shot dict + references into a PromptPackage.

    Handles prompt building, reference selection and ordering,
    and model targeting based on pipeline pass.
    """

    def __init__(
        self,
        breakdown: dict,
        project_config: dict,
        starsend_config: dict,
        asset_manager: AssetManager,
    ):
        self.breakdown = breakdown
        self.project_config = project_config
        self.starsend_config = starsend_config
        self.assets = asset_manager

    def compile_grid_package(
        self,
        shots: list[dict],
        storyboard: dict,
        scene_index: int,
        episode: int,
        character_data: Optional[dict] = None,
    ) -> PromptPackage:
        """Compile a grid exploration package for a scene's shots.

        Uses the configured exploration model. Four or more shots get a 3x3
        scene-coverage grid; fewer get a 2x2 director's-take grid. Identity
        references are collected once per unique character (case-insensitive)
        across all shots, with the reference budget shared evenly between them.

        Args:
            shots: Shot dicts belonging to the scene.
            storyboard: Storyboard dict; its "project" entry selects the
                character reference set (DEFAULT_PROJECT when absent).
            scene_index: Not read by this method.
            episode: Not read by this method.
            character_data: Optional per-character data forwarded to the
                prompt builder.

        Returns:
            PromptPackage keyed to the first shot's id (0 when *shots* is empty).
        """
        character_data = character_data or {}
        model = self.starsend_config.get(
            "exploration_model", get_model("exploration", "image")
        )

        # Determine grid type based on shot count
        if len(shots) >= 4:
            grid_type = GridType.SCENE_COVERAGE
            grid_size = "3x3"
        else:
            grid_type = GridType.DIRECTORS_TAKE
            grid_size = "2x2"

        # Build the grid prompt
        prompt_text = build_grid_prompt(
            shots=shots,
            storyboard=storyboard,
            breakdown=self.breakdown,
            project_config=self.project_config,
            grid_type=grid_type,
            character_data=character_data,
            grid_size=grid_size,
        )

        # Build reference list for the grid
        max_refs = self.starsend_config.get("max_references_per_shot", 7)

        # Collect unique characters first (first-seen order, case-insensitive)
        # so the per-character budget is based on the final head count rather
        # than on how many characters happen to have been seen so far.
        unique_chars: list[str] = []
        chars_seen: set[str] = set()
        for shot in shots:
            for char_key in shot.get("characters_in_shot", []):
                lowered = char_key.lower()
                if lowered not in chars_seen:
                    chars_seen.add(lowered)
                    unique_chars.append(char_key)

        all_refs: list[ReferenceImage] = []
        if unique_chars:
            project = storyboard.get("project") or DEFAULT_PROJECT
            per_char_max = max(1, (max_refs - 1) // len(unique_chars))
            for char_key in unique_chars:
                char_paths = get_character_refs(char_key, project)
                all_refs.extend(
                    self.assets.get_identity_refs(
                        char_key, char_paths, max_refs=per_char_max
                    )
                )

        refs = self.assets.build_shot_refs(
            character_refs=all_refs,
            max_total=max_refs,
        )

        shot_id = shots[0].get("id", 0) if shots else 0

        return PromptPackage(
            shot_id=shot_id,
            prompt_text=prompt_text,
            references=refs,
            model=model,
            aspect_ratio=self.starsend_config.get("grid_aspect_ratio", "1:1"),
            image_size=self.starsend_config.get("default_image_size", "4K"),
            num_candidates=1,
            is_env=all(len(s.get("characters_in_shot", [])) == 0 for s in shots),
        )

    def compile_production_package(
        self,
        shot: dict,
        storyboard: dict,
        episode: int,
        scene_ref_path: Optional[Path] = None,
        pose_ref_path: Optional[Path] = None,
        character_data: Optional[dict] = None,
        num_candidates: int = 1,
    ) -> PromptPackage:
        """Compile a production render package for a single shot.

        A shot with no characters is treated as environment-only and gets a
        no-people directive. A shot with two or more characters uses the
        two-character prompt builder on the first two entries; everything
        else uses the cinematic prompt builder.

        Args:
            shot: Shot dict.
            storyboard: Storyboard dict; its "project" entry selects the
                character reference set (DEFAULT_PROJECT when absent).
            episode: Episode number used to resolve character data.
            scene_ref_path: Optional scene reference; ignored if it does not exist.
            pose_ref_path: Optional pose reference; ignored if it does not exist.
            character_data: Optional per-character data, keyed by lower-cased
                character key.
            num_candidates: Candidate count recorded on the package.

        Returns:
            PromptPackage for the shot.
        """
        character_data = character_data or {}
        model = self.starsend_config.get(
            "default_model", get_model("production", "image")
        )
        chars_in_shot = shot.get("characters_in_shot", [])
        is_env = len(chars_in_shot) == 0
        max_refs = self.starsend_config.get("max_references_per_shot", 7)

        # Build prompt
        if len(chars_in_shot) >= 2 and not is_env:
            char_a_key = chars_in_shot[0]
            char_b_key = chars_in_shot[1]
            char_a_data = self._resolve_char_data(char_a_key, episode, character_data)
            char_b_data = self._resolve_char_data(char_b_key, episode, character_data)
            prompt_text = build_two_character_prompt(
                shot=shot,
                storyboard=storyboard,
                char_a_data=char_a_data,
                char_b_data=char_b_data,
                project_config=self.project_config,
            )
        else:
            prompt_text = build_cinematic_prompt(
                shot=shot,
                storyboard=storyboard,
                character_data=character_data,
                project_config=self.project_config,
                is_env=is_env,
            )

        # Build references
        scene_ref = None
        if scene_ref_path and scene_ref_path.exists():
            scene_ref = self.assets.get_scene_ref(scene_ref_path)

        pose_ref = None
        if pose_ref_path and pose_ref_path.exists():
            pose_ref = self.assets.get_pose_ref(pose_ref_path)

        expression_ref = None
        identity_refs: list[ReferenceImage] = []

        if not is_env:
            emotion = shot.get("emotion", "")
            if emotion:
                expression_ref = self.assets.get_expression_ref(emotion)

            # Loop-invariant: same project and same per-character budget for
            # every character in the shot.
            project = storyboard.get("project") or DEFAULT_PROJECT
            per_char_max = max(1, (max_refs - 2) // max(len(chars_in_shot), 1))
            for char_key in chars_in_shot:
                char_paths = get_character_refs(char_key, project)
                identity_refs.extend(
                    self.assets.get_identity_refs(
                        char_key, char_paths, max_refs=per_char_max
                    )
                )

        refs = self.assets.build_shot_refs(
            character_refs=identity_refs,
            scene_ref=scene_ref,
            pose_ref=pose_ref,
            expression_ref=expression_ref,
            max_total=max_refs,
        )

        directives = []
        if is_env:
            directives.append(
                "CRITICAL: This is an ENVIRONMENT-ONLY shot. "
                "ABSOLUTELY NO PEOPLE in this image."
            )

        return PromptPackage(
            shot_id=shot.get("id", 0),
            prompt_text=prompt_text,
            references=refs,
            model=model,
            aspect_ratio=self.starsend_config.get("production_aspect_ratio", "9:16"),
            image_size=self.starsend_config.get("default_image_size", "4K"),
            num_candidates=num_candidates,
            is_env=is_env,
            directives=directives,
        )

    def _resolve_char_data(
        self, char_key: str, episode: int, character_data: dict
    ) -> dict:
        """Resolve character visual data for two-character prompt building.

        Lookup order: the caller-supplied *character_data* (keyed by the
        lower-cased character key), then the validated handoff, then the
        legacy per-episode resolver, and finally a name-only placeholder.
        """
        if char_key.lower() in character_data:
            return character_data[char_key.lower()]

        try:
            from recoil.pipeline._lib.render_schema import validate_handoff

            handoff = validate_handoff(char_key, episode, self.breakdown)
            return {
                "name": handoff.display_name,
                "visual": handoff.visual_description,
                "wardrobe": handoff.wardrobe_description,
                "identity_type": handoff.identity_type,
            }
        except (FileNotFoundError, KeyError, AttributeError):
            # Tenet 6: only the documented "missing data" failure modes fall
            # through to the legacy resolver. Schema/format errors (JSONDecode,
            # ValueError, TypeError) propagate so the caller sees the real bug
            # instead of a silently-empty-wardrobe character.
            try:
                resolved = resolve_character_for_episode(char_key, episode)
                return {
                    "name": resolved["display_name"],
                    "visual": resolved["visual_description"],
                    "wardrobe": resolved["wardrobe_desc"],
                    "identity_type": "non_human"
                    if "android" in resolved.get("visual_description", "").lower()
                    else "human",
                }
            except (FileNotFoundError, KeyError):
                # NOTE(review): unlike the branches above, this placeholder
                # carries no "identity_type" key — confirm consumers use .get().
                return {"name": char_key.title(), "visual": "", "wardrobe": ""}


# ── Helper functions (used by strategies) ─────────────────────────────


def _check_budget(context: PipelineContext, estimated_cost: float) -> None:
    """Raise BudgetExceededError if generation would exceed budget cap."""
    if context.budget_cap <= 0:
        return  # No budget enforcement
    current = context.cost_tracker.total_cost
    if current + estimated_cost > context.budget_cap:
        raise BudgetExceededError(current, estimated_cost, context.budget_cap)


def _generate_image(
    package: PromptPackage, context: PipelineContext
) -> Optional[bytes]:
    """Request one image for ``package`` from the client registered for its model.

    Returns:
        The image bytes on success. ``None`` in every other case: dry run,
        an unsuccessful result, a successful result without image data, a
        model whose client raises NotImplementedError, or any other exception
        raised while obtaining or calling the client (all of which are logged
        rather than propagated).
    """
    if context.dry_run:
        logger.info(f"  [DRY RUN] Would generate with {package.model}")
        return None

    try:
        outcome = get_client(package.model).generate(package)
        if not outcome.success:
            logger.warning(
                "  Generation failed for shot %s: %s",
                package.shot_id,
                outcome.error,
            )
        elif outcome.image_data:
            return outcome.image_data
    except NotImplementedError as exc:
        logger.warning("  Model not yet implemented: %s", exc)
    except Exception as exc:
        # Best effort: a failed API call is reported and treated as "no image".
        logger.error("  API call failed: %s", exc)

    return None


def _split_grid(image_data: bytes, grid_type: GridType) -> list[bytes]:
    """Split a grid image into individual JPEG-encoded panels.

    ``GridType.SCENE_COVERAGE`` grids are cut 3x3; every other grid type is
    cut 2x2. Panels are returned in row-major order. Panel dimensions are the
    integer quotient of the image dimensions, so remainder pixels on the
    right and bottom edges are dropped.

    Args:
        image_data: Encoded image bytes that Pillow can open.
        grid_type: Grid layout of the image.

    Returns:
        One JPEG byte string (quality 85) per panel, or an empty list when
        Pillow is not installed.
    """
    if not _HAS_PIL:
        logger.error("Pillow not installed. Cannot split grid.")
        return []

    rows, cols = (3, 3) if grid_type == GridType.SCENE_COVERAGE else (2, 2)

    img = Image.open(BytesIO(image_data))

    # Pillow's JPEG writer only accepts these modes; anything else (RGBA, P,
    # LA, PA, I;16, F, ...) makes save() raise. Normalise to RGB once for the
    # whole grid instead of once per panel.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if img.mode not in jpeg_modes:
        img = img.convert("RGB")

    width, height = img.size
    panel_w = width // cols
    panel_h = height // rows

    panels: list[bytes] = []
    for row in range(rows):
        upper = row * panel_h
        for col in range(cols):
            left = col * panel_w
            panel = img.crop((left, upper, left + panel_w, upper + panel_h))
            # JPEG rather than PNG: much smaller on disk for photographic panels.
            buf = BytesIO()
            panel.save(buf, format="JPEG", quality=85)
            panels.append(buf.getvalue())

    return panels


def _save_output(
    image_data: bytes, shot: dict, context: PipelineContext, suffix: str = ""
) -> Path:
    """Write image bytes into the context's output directory.

    The file name is
    ``{episode prefix}_S{scene index:02d}_{sanitised shot name}{suffix}.png``.
    The shot name falls back to ``shot_{id}``; every character that is not
    alphanumeric, ``_`` or ``-`` is replaced with ``_``. The ``.png``
    extension is fixed and does not depend on the bytes being written.

    Returns:
        Path of the file that was written.
    """
    prefix = _ep_prefix(context)
    scene_no = shot.get("_scene_index", 0)
    raw_name = shot.get("name", f"shot_{shot.get('id', 0)}")

    allowed_punctuation = ("_", "-")
    cleaned_chars = [
        ch if ch.isalnum() or ch in allowed_punctuation else "_" for ch in raw_name
    ]
    safe_name = "".join(cleaned_chars)

    target = context.output_dir / f"{prefix}_S{scene_no:02d}_{safe_name}{suffix}.png"
    target.write_bytes(image_data)
    logger.info(f"  Saved: {target.name}")
    return target


def _ep_prefix(context: PipelineContext) -> str:
    """Build episode file prefix: {PRJ}_EP{NNN}"""
    project_prefix = context.project[:3].upper()
    return f"{project_prefix}_EP{context.episode:03d}"


# ── Pipeline (thin orchestrator) ──────────────────────────────────────


class Pipeline:
    """Main generation orchestrator.

    Manages the generation lifecycle:
    1. Load episode data from Recoil
    2. Plan scenes (ENV first ordering)
    3. Route shots to sub-pipelines
    4. Group multi-shot batches
    5. Dispatch to strategies
    6. Track costs + enforce budget
    7. Save outputs + log
    """

    def __init__(
        self, episode: int = 1, project: Optional[str] = None, dry_run: bool = False
    ):
        """Load episode data and wire up everything the strategies share.

        Args:
            episode: Episode number to generate.
            project: Project name; ``None`` selects ``DEFAULT_PROJECT``.
            dry_run: When true, the output directories are not created and
                the shared context is flagged as a dry run.

        Raises:
            FileNotFoundError: If the Starsend config is missing or empty.
            AssertionError: If a client-deliverable project resolves zero
                characters.
        """
        if project is None:
            project = DEFAULT_PROJECT
        self.episode = episode
        self.project = project
        self.dry_run = dry_run
        self.session_id = str(uuid.uuid4())  # ADR H05: session-based crash recovery

        # Load project config first — needed to determine project_type
        self._project_config_path = projects_root() / project / "project_config.json"
        if self._project_config_path.exists():
            self.config = json.loads(
                self._project_config_path.read_text(encoding="utf-8")
            )
        else:
            self.config = load_project_config(project)

        from recoil.core.project import get_project as _get_project

        # Looked up once; the same flag drives both the data loading below and
        # the zero-character check after character resolution.
        _proj = _get_project(self.project) if self.project else None
        is_client = _proj is not None and _proj.is_client_deliverable

        if is_client:
            from recoil.pipeline._lib.client_bridge import (
                load_client_storyboard,
                load_client_bible,
            )

            self.storyboard = load_client_storyboard(project, episode)
            self.breakdown = load_client_bible(project)
        else:
            self.storyboard = load_storyboard(episode, project)
            self.breakdown = load_breakdown(project)
        self.assets = AssetManager()
        self.cost_tracker = CanonicalCostTracker(episode, project=project)

        # Load starsend config
        self.starsend_config = get_config()
        if not self.starsend_config:
            raise FileNotFoundError(f"Starsend config required: {CONFIG_PATH}")

        # Prepare output directories (v2 layout: sequences/, renders/)
        _proj_paths = ProjectPaths.for_project(project)
        self._output_dir = _proj_paths.episode_prep_dir(episode)
        self._video_dir = _proj_paths.episode_renders_dir(episode)
        self._grids_dir = self._output_dir / "grids"
        self._panels_dir = self._output_dir / "panels"
        self._scene_refs_dir = self._output_dir / "scene_refs"

        if not self.dry_run:
            self._output_dir.mkdir(parents=True, exist_ok=True)
            self._video_dir.mkdir(parents=True, exist_ok=True)
            self._grids_dir.mkdir(exist_ok=True)
            self._panels_dir.mkdir(exist_ok=True)
            self._scene_refs_dir.mkdir(exist_ok=True)

        # Layer 1: Log for shot state tracking
        self.log = EpisodeLog(episode, self._output_dir)

        # ExecutionStore — canonical shot/take tracking for mobile + console UIs
        self._exec_store = None
        try:
            self._exec_store = ExecutionStore(project=project)
        except Exception as e:
            # Best effort: the pipeline runs without a store (no StepRunner,
            # no escalation checks) when it cannot be opened.
            logger.warning("Could not open ExecutionStore: %s", e)

        # Layer 2: Execution plan for dynamic state + crash recovery
        self.execution_plan = None
        try:
            from orchestrator.execution_plan import ExecutionPlan

            episode_id = f"EP{episode:03d}"
            self.execution_plan = ExecutionPlan(episode_id)
            # Try to load existing execution plan
            self.execution_plan.load()
            # Detect orphans from crashed sessions
            orphans = self.execution_plan.detect_orphans()
            if orphans:
                logger.warning(
                    "Session %s: %d orphaned shots from previous run",
                    self.session_id[:8],
                    len(orphans),
                )
                for orphan in orphans:
                    self.execution_plan.recover_orphan(orphan["shot_id"])
                self.execution_plan.checkpoint()
        except ImportError:
            pass  # execution_plan module not loaded

        self.assembler = ShotAssembler(
            breakdown=self.breakdown,
            project_config=self.config,
            starsend_config=self.starsend_config,
            asset_manager=self.assets,
        )

        # Pre-resolve character data for the episode
        self._character_data: dict[str, dict] = {}
        self._resolve_episode_characters()
        if is_client and not self._character_data:
            # Explicit raise instead of `assert` so the check still runs under
            # `python -O`; the exception type is unchanged for callers.
            raise AssertionError(
                f"Client project '{project}' resolved zero characters. "
                f"Check client_bible.json has a 'characters' dict with entries."
            )
        self._scene_refs: dict[int, Path] = {}

        # Create StepRunner for video generation (shared across strategies)
        if self._exec_store:
            self._step_runner = StepRunner(
                store=self._exec_store,
                paths=EpisodeProjectPaths.for_episode(self.project, self.episode),
                episode=self.episode,
            )
        else:
            self._step_runner = None

        # Budget cap from config (0.0 when the key is absent)
        budget_cap = self.starsend_config.get("routing", {}).get(
            "budget_cap_per_episode", 0.0
        )

        # Build shared context for strategies
        self.context = PipelineContext(
            episode=episode,
            project=project,
            dry_run=dry_run,
            storyboard=self.storyboard,
            breakdown=self.breakdown,
            project_config=self.config,
            starsend_config=self.starsend_config,
            asset_manager=self.assets,
            cost_tracker=self.cost_tracker,
            log=self.log,
            execution_store=self._exec_store,
            step_runner=self._step_runner,
            assembler=self.assembler,
            character_data=self._character_data,
            scene_refs=self._scene_refs,
            output_dir=self._output_dir,
            video_dir=self._video_dir,
            grids_dir=self._grids_dir,
            panels_dir=self._panels_dir,
            scene_refs_dir=self._scene_refs_dir,
            budget_cap=budget_cap,
        )

    # ── Extraction orchestration (ADR H08) ──────────────────────────

    @staticmethod
    def run_extraction(
        project: str = None,
        project_root: Path = None,
        episode_nums: list[int] | None = None,
        extraction_model: str = "opus-4.6",
        batch_size: int = 5,
        dry_run: bool = False,
    ) -> dict:
        """Dispatch extraction sub-agent skills for render plan generation.

        Orchestrates the 3-stage pipeline (camera-test → global-bible →
        storyboard-pass → prompt-gen) with parallelism:
          - Camera tests: parallel across episode batches
          - Global bible: sequential (each batch merges with accumulated bible)
          - Storyboard passes: parallel per-episode (after bible complete)
          - Prompt generation: parallel per-scene within an episode

        Each sub-agent validates before writing output; failure → retry
        with error context (up to 2 retries).

        Returns:
            dict with keys: camera_tested, bible, logs, prompts
        """
        from orchestrator.ingest_pipeline import IngestPipeline

        if project is None:
            project = DEFAULT_PROJECT

        if project_root is None:
            from recoil.core.paths import RECOIL_ROOT

            project_root = RECOIL_ROOT / project

        pipeline = IngestPipeline(
            project=project,
            project_root=project_root,
            dry_run=dry_run,
            extraction_model=extraction_model,
            batch_size=batch_size,
        )

        if episode_nums is None:
            episode_nums = pipeline._episode_ids()

        results = {
            "camera_tested": {},
            "bible": None,
            "logs": {},
            "extraction_model": extraction_model,
            "batch_size": batch_size,
        }

        # Stage 0: Camera-test in parallel batches
        for i in range(0, len(episode_nums), batch_size):
            batch = episode_nums[i : i + batch_size]
            logger.info(
                "Extraction: camera-test batch %d/%d (EP%03d-EP%03d)",
                i // batch_size + 1,
                (len(episode_nums) + batch_size - 1) // batch_size,
                batch[0],
                batch[-1],
            )
            for ep_num in batch:
                try:
                    ct = pipeline.run_camera_test(episode_num=ep_num)
                    if ct is not None:
                        results["camera_tested"][ep_num] = ct
                except Exception as e:
                    logger.error("Camera-test EP%03d failed: %s", ep_num, e)

        if dry_run:
            logger.info("[DRY RUN] Extraction dry-run complete")
            return results

        # Stage 1: Global bible (sequential — each batch merges)
        try:
            bible = pipeline.run_breakdown_pass(episode_nums, merge=True)
            results["bible"] = bible
        except Exception as e:
            logger.error("Breakdown pass failed: %s", e)
            return results

        if bible is None:
            return results

        # Stage 2: Storyboard passes (parallel per-episode)
        for ep_num in episode_nums:
            if ep_num not in results["camera_tested"]:
                continue
            try:
                plan_result = pipeline.run_storyboard_pass(ep_num, bible)
                if plan_result is not None:
                    results["logs"][ep_num] = plan_result
            except Exception as e:
                logger.error("Storyboard EP%03d failed: %s", ep_num, e)

        logger.info(
            "Extraction complete: %d camera-tested, %s bible, %d logs",
            len(results["camera_tested"]),
            "1" if results["bible"] else "no",
            len(results["logs"]),
        )
        return results

    def _resolve_episode_characters(self):
        """Fill ``self._character_data`` with visual data for the episode's cast.

        Character keys come from the shots' ``asset_data.characters[*].char_id``
        when the storyboard's ``_source`` is ``"manifest"``, otherwise from
        the storyboard's top-level ``characters`` mapping. Entries are stored
        under the lower-cased key.

        Resolution per character:
          * Client-deliverable project: read flat from the ``characters``
            mapping of ``self.breakdown``; a missing entry is logged and
            skipped.
          * Otherwise: ``validate_handoff``; on any exception fall back to
            ``resolve_character_for_episode``. If that raises
            FileNotFoundError or KeyError the character is logged and left
            unresolved.
        """
        source = self.storyboard.get("_source", "recoil")
        from recoil.core.project import get_project as _get_project

        project_info = _get_project(self.project) if self.project else None
        client_mode = project_info is not None and project_info.is_client_deliverable

        if source != "manifest":
            char_keys = list(self.storyboard.get("characters", {}).keys())
        else:
            found_ids = {
                entry.get("char_id", "")
                for shot in self.storyboard.get("shots", [])
                for entry in shot.get("asset_data", {}).get("characters", [])
            }
            char_keys = [cid for cid in found_ids if cid]

        # (output key, handoff attribute) pairs copied from a validated handoff.
        handoff_fields = (
            ("name", "display_name"),
            ("visual", "visual_description"),
            ("wardrobe", "wardrobe_description"),
            ("hair_makeup", "hair_makeup_description"),
            ("height_cm", "height_cm"),
            ("distinguishing_marks", "distinguishing_marks"),
            ("identity_type", "identity_type"),
        )

        for char_key in char_keys:
            if client_mode:
                # Client video: flat resolution from client_bible (no phases, no handoff)
                bible_entry = self.breakdown.get("characters", {}).get(char_key, {})
                if not bible_entry:
                    logger.warning(f"Client bible missing character: {char_key}")
                    continue
                read = bible_entry.get
                self._character_data[char_key.lower()] = {
                    "name": read("display_name", char_key),
                    "visual": read("visual_description", ""),
                    "wardrobe": read("wardrobe_description", ""),
                    "hair_makeup": read("hair_makeup_description", ""),
                    "height_cm": read("height_cm", 170),
                    "distinguishing_marks": read("distinguishing_marks", ""),
                    "identity_type": read("identity_type", "human"),
                }
                continue

            try:
                from recoil.pipeline._lib.render_schema import validate_handoff

                handoff = validate_handoff(char_key, self.episode, self.breakdown)
                self._character_data[char_key.lower()] = {
                    out_key: getattr(handoff, attr) for out_key, attr in handoff_fields
                }
            except Exception as handoff_err:
                logger.warning(
                    f"Handoff validation failed for {char_key}: {handoff_err}, using legacy resolve"
                )
                try:
                    legacy = resolve_character_for_episode(
                        char_key, self.episode, self.project
                    )
                    entry = {
                        "name": legacy["display_name"],
                        "visual": legacy["visual_description"],
                        "wardrobe": legacy["wardrobe_desc"],
                        "hair_makeup": legacy["hair_makeup"],
                        "height_cm": legacy["height_cm"],
                        "distinguishing_marks": legacy.get("distinguishing_marks", ""),
                    }
                    looks_android = (
                        "android" in legacy.get("visual_description", "").lower()
                    )
                    entry["identity_type"] = "non_human" if looks_android else "human"
                    self._character_data[char_key.lower()] = entry
                except (FileNotFoundError, KeyError) as legacy_err:
                    logger.warning(
                        f"Could not resolve character {char_key}: {legacy_err}"
                    )
    # ── Public API ────────────────────────────────────────────────────

    def run(self, shot_ids: Optional[list] = None, tier_override: Optional[str] = None):
        """Run the pipeline for the specified shots (or all shots).

        Steps: partition and plan the scenes, filter by ``shot_ids``, route
        each shot to a sub-pipeline, execute the single shots, execute the
        multi-shot groups, then persist the logs and emit a summary.

        Args:
            shot_ids: Shots to generate; ``None`` means every shot. Accepts
                both int (legacy) and str (plan, e.g. "EP001_SH01") IDs; each
                int is additionally matched in its ``EP{NNN}_SH{NN}`` form.
            tier_override: Tier written to every selected shot's ``_tier``.
                ``"simple"`` additionally forces all shots through the still
                pipeline.

        Once a BudgetExceededError is raised no further shot or multi-shot
        group is started; the logs are still saved.
        """
        scenes = get_all_scenes(self.storyboard)
        # Partition long scenes (>8 shots) into valid sub-batches before planning
        partitioned_scenes = []
        for scene in scenes:
            partitioned_scenes.extend(partition_long_scene(scene))
        planned_scenes = plan_episode(partitioned_scenes)

        # Normalize shot_ids for filtering
        shot_id_set = None
        if shot_ids is not None:
            shot_id_set = set()
            for sid in shot_ids:
                shot_id_set.add(sid)
                # Also add the plan-style form of integer ids
                if isinstance(sid, int):
                    shot_id_set.add(f"EP{self.episode:03d}_SH{sid:02d}")

        # --tier simple forces ALL shots through StillPipeline; the configured
        # still model does not depend on the shot, so look it up once.
        force_still = tier_override == "simple"
        forced_still_model = None
        if force_still:
            from recoil.core.paths import get_config as _get_cfg

            forced_still_model = (
                _get_cfg()
                .get("routing", {})
                .get("models", {})
                .get("still", "gemini-3-pro-image-preview")
            )

        # Flatten, filter, and route shots
        shots_to_run: list[dict] = []
        multi_shot_groups: list[list[dict]] = []

        for scene in planned_scenes:
            scene_shot_list = list(scene)
            current_multi_group: list[dict] = []

            for shot in scene:
                # Get shot identifier for filtering
                shot_identifier = shot.get("shot_id", shot.get("id", 0))
                if shot_id_set is not None and shot_identifier not in shot_id_set:
                    continue

                if tier_override:
                    shot["_tier"] = tier_override

                # Route shot to sub-pipeline
                if force_still:
                    routing = {
                        "pipeline": "still",
                        "model": forced_still_model,
                        "reason": "tier=simple override → still pipeline",
                    }
                else:
                    routing = route_shot(shot, scene_shot_list)
                shot["_pipeline"] = routing["pipeline"]
                shot["_target_model"] = routing["model"]
                shot["_routing_reason"] = routing["reason"]

                if routing["pipeline"] != "still":
                    logger.info(
                        "  Shot %s → %s pipeline (%s) — %s",
                        shot_identifier,
                        routing["pipeline"],
                        routing["model"],
                        routing["reason"],
                    )

                # Group consecutive multi_shot shots together
                if routing["pipeline"] == "multi_shot":
                    current_multi_group.append(shot)
                else:
                    # Flush any pending multi-shot group
                    if current_multi_group:
                        multi_shot_groups.append(current_multi_group)
                        current_multi_group = []
                    shots_to_run.append(shot)

            # Flush remaining multi-shot group at end of scene
            if current_multi_group:
                multi_shot_groups.append(current_multi_group)

        if not shots_to_run and not multi_shot_groups:
            logger.warning("No shots to generate.")
            return

        # Initialize log for all shots
        all_shot_ids = [_shot_id(s) for s in shots_to_run]
        for group in multi_shot_groups:
            all_shot_ids.extend(_shot_id(s) for s in group)
        self.log.init_shots(all_shot_ids)

        total = len(shots_to_run) + sum(len(g) for g in multi_shot_groups)
        logger.info(
            f"Pipeline: EP{self.episode:03d} — {total} shots "
            f"({'DRY RUN' if self.dry_run else 'LIVE'})"
        )
        if multi_shot_groups:
            logger.info(
                f"  Multi-shot groups: {len(multi_shot_groups)} "
                f"({sum(len(g) for g in multi_shot_groups)} shots)"
            )

        results: list[dict] = []
        budget_exceeded = False

        # Execute individual shots
        for i, shot in enumerate(shots_to_run, 1):
            shot_id = _shot_id(shot)

            # Skip manually escalated shots (being handled in Manual Workbench)
            if self._is_manually_escalated(shot_id):
                logger.info(f"  Skipping {shot_id} — manually escalated, in workbench")
                continue

            tier = shot.get("_tier", "standard")
            pipeline_name = shot.get("_pipeline", "still")
            scene_index = shot.get("_scene_index", 0)

            logger.info(
                f"[{i}/{total}] Shot {shot_id} — tier={tier}, "
                f"pipeline={pipeline_name}, scene={scene_index}"
            )

            self.log.update_shot(
                shot_id,
                "submitted",
                tier=tier,
                pipeline=pipeline_name,
                model=shot.get("_target_model"),
            )

            try:
                strategy = get_strategy(pipeline_name)
                result = strategy.execute(shot, self.context)
                results.append(result)

                # Update log
                status = result.get("status", "error")
                if status == "ok":
                    self.log.update_shot(
                        shot_id,
                        "complete",
                        output_path=result.get("output"),
                        cost=result.get(
                            "cost"
                        ),  # intentional None shape — not migrated
                    )
                    # Register in ExecutionStore
                    self._register_in_store(
                        shot_id,
                        result,
                        pipeline_name,
                        shot.get("_target_model"),
                        "keyframe_generated",
                    )
                elif status == "dry_run":
                    self.log.update_shot(shot_id, "complete")
                else:
                    self.log.update_shot(
                        shot_id,
                        "failed",
                        error=result.get("error"),
                    )
                    # Register failure in store
                    self._register_in_store(
                        shot_id,
                        result,
                        pipeline_name,
                        shot.get("_target_model"),
                        "failed",
                    )

            except NotImplementedError as e:
                logger.warning(f"Shot {shot_id}: {e}")
                results.append(
                    {
                        "shot_id": shot_id,
                        "status": "skipped",
                        "reason": str(e),
                    }
                )
                self.log.update_shot(shot_id, "failed", error=str(e))
            except BudgetExceededError as e:
                logger.error(f"Budget exceeded at shot {shot_id}: {e}")
                results.append(
                    {
                        "shot_id": shot_id,
                        "status": "budget_exceeded",
                        "error": str(e),
                    }
                )
                self.log.update_shot(shot_id, "failed", error=str(e))
                budget_exceeded = True
                break  # Stop processing more shots
            except Exception as e:
                logger.error(f"Shot {shot_id} failed: {e}")
                results.append(
                    {
                        "shot_id": shot_id,
                        "status": "error",
                        "error": str(e),
                    }
                )
                self.log.update_shot(shot_id, "failed", error=str(e))

        # A blown budget stops the multi-shot groups as well, not only the
        # remaining single shots.
        if budget_exceeded and multi_shot_groups:
            logger.warning(
                "Budget exceeded — skipping %d multi-shot group(s)",
                len(multi_shot_groups),
            )
            multi_shot_groups = []

        # Execute multi-shot groups
        for group in multi_shot_groups:
            # Filter out manually escalated shots from the group. Shots are
            # identified with _shot_id(), the same id the log was initialised
            # with above.
            runnable: list[dict] = []
            for s in group:
                sid = _shot_id(s)
                if self._is_manually_escalated(sid):
                    logger.info(f"  Skipping {sid} — manually escalated, in workbench")
                    continue
                runnable.append(s)
            group = runnable
            if not group:
                continue

            group_ids = [_shot_id(s) for s in group]
            logger.info(f"Multi-shot group: shots {group_ids}")
            try:
                strategy = get_strategy("multi_shot")
                group_results = strategy.execute_batch(group, self.context)
                results.extend(group_results)
                # Register each result in ExecutionStore
                for gr in group_results:
                    gr_status = "video_ready" if gr.get("status") == "ok" else "failed"
                    self._register_in_store(
                        gr.get("shot_id"),
                        gr,
                        gr.get("pipeline", "multi_shot"),
                        gr.get("model"),
                        gr_status,
                    )
            except NotImplementedError as e:
                logger.warning(f"Multi-shot group skipped: {e}")
                for sid in group_ids:
                    results.append(
                        {
                            "shot_id": sid,
                            "status": "skipped",
                            "reason": str(e),
                        }
                    )
                    self.log.update_shot(sid, "failed", error=str(e))
            except BudgetExceededError as e:
                logger.error(f"Budget exceeded at multi-shot group {group_ids}: {e}")
                for sid in group_ids:
                    results.append(
                        {
                            "shot_id": sid,
                            "status": "budget_exceeded",
                            "error": str(e),
                        }
                    )
                    self.log.update_shot(sid, "failed", error=str(e))
                break  # Stop processing more groups
            except Exception as e:
                # Same policy as for single shots: record the failure and keep
                # going so the logs below are always written.
                logger.error(f"Multi-shot group {group_ids} failed: {e}")
                for sid in group_ids:
                    results.append(
                        {
                            "shot_id": sid,
                            "status": "error",
                            "error": str(e),
                        }
                    )
                    self.log.update_shot(sid, "failed", error=str(e))

        # Save log, execution plan, and cost log
        self._save_log(results)
        self.log.save()
        self.cost_tracker.save_log()
        if self.execution_plan:
            self.execution_plan.checkpoint()

        # Summary
        success = sum(1 for r in results if r.get("status") == "ok")
        failed = sum(1 for r in results if r.get("status") == "error")
        skipped = sum(1 for r in results if r.get("status") == "skipped")
        logger.info(
            f"Pipeline complete: {success} succeeded, {failed} failed, "
            f"{skipped} skipped, total cost: ${self.cost_tracker.total_cost:.2f}"
        )
        logger.info(self.cost_tracker.summary())

    def _is_manually_escalated(self, shot_id) -> bool:
        """Return True when the store marks the shot escalated and not yet resolved.

        Always False when no ExecutionStore is open or the store has no
        record of the shot.
        """
        if not self._exec_store:
            return False
        store_shot = self._exec_store.get_shot(shot_id)
        if not store_shot:
            return False
        gates = store_shot.get("gate_results", {})
        return bool(gates.get("manual_escalated") and not gates.get("manual_resolved"))

    def _register_in_store(self, shot_id, result, pipeline_name, model, store_status):
        """Register a pipeline failure in the ExecutionStore (success handled by StepRunner)."""
        if not self._exec_store or not shot_id:
            return
        if store_status == "failed":
            try:
                self._exec_store.update_shot(
                    shot_id,
                    status="failed",
                    error_message=result.get("error", ""),
                )
            except Exception as e:
                logger.warning("Failed to register error for %s: %s", shot_id, e)

    def compile_shot(
        self, shot_id: int, tier_override: Optional[str] = None
    ) -> PromptPackage:
        """Compile a single shot into a PromptPackage without generating.

        Args:
            shot_id: ID of the shot to look up in this episode's storyboard.
            tier_override: When given, the shot is compiled with ``_tier`` set
                to this value — the same tag ``run()`` writes for its
                override. The storyboard's own shot dict is left untouched.

        Returns:
            The package produced by the assembler's
            ``compile_production_package``.

        Raises:
            ValueError: If the shot is not in the storyboard.
        """
        shot = get_shot_by_id(self.storyboard, shot_id)
        if shot is None:
            raise ValueError(
                f"Shot {shot_id} not found in EP{self.episode:03d} storyboard"
            )

        if tier_override:
            # Previously the argument was accepted but ignored. Tag a shallow
            # copy so a preview never rewrites the loaded storyboard.
            shot = {**shot, "_tier": tier_override}

        scene_index = self._find_scene_index(shot_id)
        scene_ref = self._scene_refs.get(scene_index)

        return self.assembler.compile_production_package(
            shot=shot,
            storyboard=self.storyboard,
            episode=self.episode,
            scene_ref_path=scene_ref,
            character_data=self._character_data,
        )

    # ── Output helpers ────────────────────────────────────────────────

    def _save_log(self, results: list[dict]):
        """Save generation log to sequences/ep_{NNN}/log.json."""
        log = {
            "episode": self.episode,
            "project": self.project,
            "timestamp": time.time(),
            "dry_run": self.dry_run,
            "total_cost": round(self.cost_tracker.total_cost, 4),
            "total_calls": self.cost_tracker.total_calls,
            "failed_calls": self.cost_tracker.failed_calls,
            "cost_by_model": {
                k: round(v, 4) for k, v in self.cost_tracker.cost_by_model().items()
            },
            "cost_by_tier": {
                k: round(v, 4) for k, v in self.cost_tracker.cost_by_tier().items()
            },
            "shots": results,
        }
        log_path = self._output_dir / "log.json"

        if self.dry_run:
            print(f"\n=== Log (would save to {log_path}) ===")
            print(json.dumps(log, indent=2, default=str))
        else:
            log_path.write_text(
                json.dumps(log, indent=2, default=str), encoding="utf-8"
            )
            logger.info(f"Log saved: {log_path}")

    def _find_scene_index(self, shot_id: int) -> int:
        """Return the index of the first scene containing a shot whose ``id`` is ``shot_id``.

        Falls back to 0 when no scene contains such a shot.
        """
        for position, members in enumerate(get_all_scenes(self.storyboard)):
            if any(member.get("id") == shot_id for member in members):
                return position
        return 0


# ── CLI ──────────────────────────────────────────────────────────────


def _parse_shot_range(shot_str: str) -> list[int]:
    """Parse shot ID string like '1-3' or '1,2,5' into a list of ints."""
    ids = []
    for part in shot_str.split(","):
        part = part.strip()
        if "-" in part:
            start, end = part.split("-", 1)
            ids.extend(range(int(start), int(end) + 1))
        else:
            ids.append(int(part))
    return ids


def main():
    """Command-line entry point for the generation pipeline.

    Parses the CLI arguments, configures root logging (DEBUG with
    ``--verbose``, INFO otherwise), expands ``--shots`` into a list of shot
    IDs, and builds a ``Pipeline`` for the requested episode/project.

    When exactly one shot is requested together with ``--dry-run``, the shot
    is compiled and its package description printed; in every other case the
    pipeline is run over the selected shots (all shots when ``--shots`` is
    omitted).

    A malformed ``--shots`` value, or one that selects no shots, is reported
    as a usage error (argparse prints the message and exits with status 2)
    rather than surfacing as a traceback or an empty selection.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Starsend generation pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Dry run all shots in EP001
  python -m orchestrator.pipeline --episode 1 --dry-run

  # Generate specific shots
  python -m orchestrator.pipeline --episode 1 --shots 1-3

  # Force simple tier for all shots
  python -m orchestrator.pipeline --episode 1 --tier simple --dry-run

  # Preview a single shot's compiled prompt
  python -m orchestrator.pipeline --episode 1 --shots 5 --dry-run
""",
    )

    parser.add_argument(
        "--episode",
        "-e",
        type=int,
        required=True,
        help="Episode number to generate",
    )
    parser.add_argument(
        "--shots",
        "-s",
        type=str,
        default=None,
        help="Shot IDs to generate (e.g. '1-3' or '1,2,5'). Default: all shots.",
    )
    parser.add_argument(
        "--tier",
        "-t",
        type=str,
        default=None,
        choices=["simple", "standard", "complex"],
        help="Force a complexity tier for all shots.",
    )
    parser.add_argument(
        "--dry-run",
        "-d",
        action="store_true",
        help="Compile prompts but don't call API. Shows PromptPackage for each shot.",
    )
    parser.add_argument(
        "--project",
        "-p",
        type=str,
        default=None,
        help="Project name (default: from pipeline_config.json)",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    # Configure logging
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    # Parse shot IDs. None means "all shots"; an omitted or empty --shots
    # keeps that default.
    shot_ids = None
    if args.shots:
        try:
            shot_ids = _parse_shot_range(args.shots)
        except ValueError as exc:
            # parser.error() prints usage plus the message and exits (status 2).
            parser.error(f"argument --shots/-s: {exc}")
        if not shot_ids:
            # e.g. a descending range that expands to nothing: refuse rather
            # than hand the pipeline an empty selection.
            parser.error(f"argument --shots/-s: {args.shots!r} selects no shots")

    # Run pipeline
    pipeline = Pipeline(
        episode=args.episode,
        project=args.project,
        dry_run=args.dry_run,
    )

    if args.dry_run and shot_ids and len(shot_ids) == 1:
        # Single shot preview — compile and display
        package = pipeline.compile_shot(shot_ids[0], tier_override=args.tier)
        print(package.describe())
    else:
        pipeline.run(shot_ids=shot_ids, tier_override=args.tier)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
