#!/usr/bin/env python3
"""recoil produce — unified production command.

Chains: validate episode -> generate breakdown -> compile manifest -> (execute via StepRunner)

Usage:
    python3 tools/produce.py <project> <episode_id> [--dry-run] [--stage] [--force]

Examples:
    python3 tools/produce.py afterimage ep_01 --dry-run
    python3 tools/produce.py afterimage ep_01 --stage    # pause between stages
    python3 tools/produce.py afterimage all               # all episodes
"""
import argparse
import json
import sys
from pathlib import Path

# Repository root: the parent of the directory holding this script.
RECOIL_ROOT = (Path(__file__).resolve()
               .parent    # directory containing this file
               .parent)   # one level above it

# Make `recoil.*` (and sibling packages) importable when run as a plain script.
sys.path.insert(0, str(RECOIL_ROOT))

from recoil.core.paths import projects_root  # noqa: E402


def execute_manifest(manifest: dict, project_name: str, episode_num: int,
                     dry_run: bool = False) -> list:
    """Walk the compiler manifest's task list and execute each task via StepRunner.

    Tasks run in manifest order:

    * ``image_generation`` tasks are dispatched as ``image_t2i``; successful
      outputs are remembered by task id.
    * ``video_generation`` tasks are dispatched as ``video_i2v``; if their
      ``source_task`` produced an output earlier in this run, that output
      (relative to the project root) is passed as ``start_frame``.
    * Any other task type is reported and skipped.

    Args:
        manifest: Compiled manifest; only its ``tasks`` list is read. A missing
            or null ``tasks`` entry is treated as an empty list.
        project_name: Project passed to the execution store, paths and dispatch context.
        episode_num: Numeric episode index.
        dry_run: When true -- or when the ``RECOIL_DRY_RUN`` environment variable
            is set to any non-empty value -- only print the tasks and return ``[]``.

    Returns:
        The ``run_result`` of every dispatched task, in dispatch order. Skipped
        tasks contribute nothing.
    """
    import os

    # `or []` also tolerates an explicit "tasks": null in the manifest.
    tasks = manifest.get('tasks') or []

    # Decide on a dry run *before* importing the execution stack, so previewing
    # a manifest works even where those modules (or their deps) are unavailable.
    if dry_run or os.environ.get('RECOIL_DRY_RUN'):
        print("DRY RUN: Would execute manifest with StepRunner")
        for task in tasks:
            print(f"  Task: {task.get('task_id')} ({task.get('type')})")
        return []

    from recoil.execution.step_runner import StepRunner
    from recoil.execution.step_types import ProjectPaths
    from recoil.execution.execution_store import ExecutionStore
    from recoil.pipeline.core.dispatch import dispatch
    from recoil.pipeline.core.dispatch_context import DispatchContext

    store = ExecutionStore(project_name)
    paths = ProjectPaths.for_episode(project_name, episode_num)
    runner = StepRunner(store=store, paths=paths, episode=episode_num)

    ctx = DispatchContext(
        caller_id="produce_cli",
        step_runner=runner,
        project=project_name,
        episode=episode_num,
    )

    results = []
    completed_tasks = {}  # task_id -> output_path of successful image tasks, for chaining

    for task in tasks:
        task_id = task.get('task_id', '')
        task_type = task.get('type', '')
        # Drop the trailing "_<suffix>" (a no-op when there is no underscore).
        # NOTE(review): assumes task ids look like "<shot_id>_<kind>" -- confirm
        # against the compiler's task id scheme.
        shot_id = task_id.rsplit('_', 1)[0]

        if task_type == 'image_generation':
            receipt = dispatch(
                "image_t2i",
                {
                    "shot_id": shot_id,
                    "prompt": task.get('prompt', ''),
                    "model": task.get('model', 'gemini-3-pro-image-preview'),
                    "aspect_ratio": task.get('aspect_ratio', '9:16'),
                },
                context=ctx,
            )
            result = receipt.run_result
            if result.success and result.output_path:
                completed_tasks[task_id] = result.output_path
            results.append(result)

        elif task_type == 'video_generation':
            # Chain from the source image only if it was produced in this run.
            start_frame = None
            source = task.get('source_task')
            if source and source in completed_tasks:
                start_frame = paths.project_root / completed_tasks[source]

            receipt = dispatch(
                "video_i2v",
                {
                    "shot_id": shot_id,
                    "prompt": task.get('prompt', ''),
                    "model": task.get('model', 'kling-v3'),
                    "start_frame": start_frame,
                    "duration": task.get('duration', 5),
                    "aspect_ratio": task.get('aspect_ratio', '9:16'),
                },
                context=ctx,
            )
            results.append(receipt.run_result)

        else:
            print(f"  SKIP: Unknown task type '{task_type}' for {task_id}")

    return results


def detect_format(project_path: Path) -> str:
    """Return the production format configured in ``project_config.json``.

    Args:
        project_path: Project root directory that may contain ``project_config.json``.

    Returns:
        The config's ``format`` value, or ``'kill_box'`` when the file does not
        exist or has no ``format`` key.

    Raises:
        json.JSONDecodeError: If the config file exists but is not valid JSON.
    """
    default_format = 'kill_box'
    config_path = project_path / 'project_config.json'
    try:
        # JSON text is UTF-8 by spec; don't depend on the locale's default encoding.
        raw = config_path.read_text(encoding='utf-8')
    except FileNotFoundError:
        return default_format
    return json.loads(raw).get('format', default_format)


def _parse_episode_number(episode_id: str):
    """Return the integer episode index encoded in *episode_id*, or None.

    Accepts ``'ep_<digits>'`` (``'ep_01'`` -> 1) and bare ``'<digits>'``.
    Anything else (e.g. ``'ep_01_draft'``) yields None instead of raising.
    """
    digits = episode_id[len('ep_'):] if episode_id.startswith('ep_') else episode_id
    # isdecimal() guarantees int() succeeds, and rejects forms int() would
    # silently accept such as '1_0' (-> 10), ' 1' or '+1'.
    return int(digits) if digits.isdecimal() else None


def produce_episode(project_name: str, episode_id: str,
                    dry_run: bool = False, stage: bool = False,
                    force: bool = False) -> bool:
    """Run the full production pipeline for a single episode.

    Stages:
        1. Validate the episode markdown (skipped when *force* is set).
        2. Generate the shot breakdown and validate it.
        3. Compile the render manifest from the breakdown.
        4. Execute the manifest via ``execute_manifest``.

    Args:
        project_name: Project directory name under ``projects_root()``.
        episode_id: Episode file stem, e.g. ``'ep_01'``. Unless *dry_run* is
            set it must encode an integer episode number (``ep_<digits>``).
        dry_run: Generate the breakdown in dry-run mode, print (a prefix of)
            its prompt, and stop after Stage 2.
        stage: Wait for Enter between stages.
        force: Skip Stage 1 episode validation.

    Returns:
        True if every stage passes and no executed task failed; False otherwise.
    """
    banner = '=' * 60

    project_path = projects_root() / project_name
    if not project_path.exists():
        print(f"ERROR: Project not found: {project_path}")
        return False

    format_name = detect_format(project_path)
    episode_path = project_path / 'episodes' / f"{episode_id}.md"

    if not episode_path.exists():
        print(f"ERROR: Episode not found: {episode_path}")
        return False

    # Stage 4 needs a numeric episode index. Check it up front so a bad id
    # fails before the expensive stages instead of crashing at the end.
    # Dry runs stop after Stage 2 and never need it.
    episode_num = _parse_episode_number(episode_id)
    if episode_num is None and not dry_run:
        print(f"ERROR: Cannot derive an episode number from '{episode_id}' (expected ep_<digits>)")
        return False

    print(banner)
    print(f"PRODUCE: {project_name} / {episode_id}")
    print(f"Format: {format_name}")
    print(banner)

    # Stage 1: Validate episode
    print("\n--- Stage 1: Validate Episode ---")
    if not force:
        from tools.validate_episode import dispatch_validate_episode
        result = dispatch_validate_episode(str(episode_path), project_path=str(project_path))
        if not result.get('valid'):
            print("FAIL: Episode validation failed")
            for err in result.get('errors', []):
                print(f"  - {err}")
            return False
        print("PASS: Episode valid")
    else:
        print("SKIP: --force flag set")

    if stage:
        input("Press Enter to continue to Stage 2 (Breakdown)...")

    # Stage 2: Generate breakdown
    print("\n--- Stage 2: Generate Breakdown ---")
    from visual.breakdown_agent import generate_breakdown
    breakdown = generate_breakdown(episode_path, project_path, format_name=format_name, dry_run=dry_run)

    if dry_run:
        print("DRY RUN: Would generate breakdown with this prompt:")
        print(breakdown.get('prompt', '')[:500] + '...')
        return True

    # NOTE(review): assumes generate_breakdown() wrote its output to this
    # path -- the returned `breakdown` object is not used past this point.
    breakdown_path = project_path / 'state' / 'visual' / 'breakdowns' / f"{episode_id}_breakdown.json"

    # Validate breakdown
    from visual.validate_breakdown import validate_breakdown
    bd_result = validate_breakdown(breakdown_path)
    if not bd_result.get('valid'):
        print("FAIL: Breakdown validation failed")
        for err in bd_result.get('errors', []):
            print(f"  - {err}")
        return False
    shot_count = bd_result.get('metrics', {}).get('total_shots', '?')
    print(f"PASS: Breakdown valid ({shot_count} shots)")

    if stage:
        input("Press Enter to continue to Stage 3 (Compile)...")

    # Stage 3: Compile manifest
    print("\n--- Stage 3: Compile Manifest ---")
    from visual.compiler import compile_from_breakdown
    manifest = compile_from_breakdown(breakdown_path, project_path, format_name=format_name)

    tasks = manifest.get('tasks') or []
    # Sorted so the printed list is stable across runs (set order is not,
    # because str hashing is randomized per process).
    task_types = sorted({str(t.get('type')) for t in tasks})
    print(f"PASS: Manifest compiled ({len(tasks)} tasks — {', '.join(task_types)})")

    if stage:
        input("Press Enter to continue to Stage 4 (Render)...")

    # Stage 4: Execute via StepRunner (dry runs already returned after Stage 2).
    print("\n--- Stage 4: Execute ---")
    results = execute_manifest(manifest, project_name, episode_num)
    passed = sum(1 for r in results if r.success)
    failed = len(results) - passed
    print(f"DONE: {passed} passed, {failed} failed out of {len(results)} tasks")

    manifest_path = project_path / 'state' / 'visual' / 'manifests' / f"{episode_id}_manifest.json"
    print(f"Manifest: {manifest_path}")

    # A render stage with failed tasks does not count as "all stages pass".
    ok = failed == 0
    print(f"\n{banner}")
    print(f"{'COMPLETE' if ok else 'INCOMPLETE'}: {project_name} / {episode_id}")
    print(banner)
    return ok


def main():
    """Parse the command line and produce one episode, or every episode of a project.

    Exits with status 0 only if every requested episode was produced
    successfully, and 1 otherwise. In ``all`` mode, a missing project
    directory or an empty set of ``episodes/ep_*.md`` files also counts as
    failure.
    """
    parser = argparse.ArgumentParser(description='Produce visual output from narrative episodes')
    parser.add_argument('project', help='Project name (e.g., afterimage)')
    parser.add_argument('episode', help='Episode ID (e.g., ep_01) or "all"')
    parser.add_argument('--dry-run', action='store_true', help='Show what would happen without executing')
    parser.add_argument('--stage', action='store_true', help='Pause between pipeline stages')
    parser.add_argument('--force', action='store_true', help='Skip episode validation')
    args = parser.parse_args()

    options = {'dry_run': args.dry_run, 'stage': args.stage, 'force': args.force}

    if args.episode != 'all':
        success = produce_episode(args.project, args.episode, **options)
        sys.exit(0 if success else 1)

    project_path = projects_root() / args.project
    if not project_path.exists():
        print(f"ERROR: Project not found: {project_path}")
        sys.exit(1)

    episodes = sorted(project_path.glob('episodes/ep_*.md'))
    if not episodes:
        print(f"ERROR: No episodes matching ep_*.md in {project_path / 'episodes'}")
        sys.exit(1)

    results = []
    for ep in episodes:
        ep_id = ep.stem
        success = produce_episode(args.project, ep_id, **options)
        results.append((ep_id, success))

    produced = sum(1 for _, ok in results if ok)
    print(f"\n{'='*60}")
    print(f"SUMMARY: {produced}/{len(results)} episodes produced")
    for ep_id, success in results:
        print(f"  {'PASS' if success else 'FAIL'}: {ep_id}")

    # Previously 'all' mode always exited 0; surface failures to the shell/CI.
    sys.exit(0 if produced == len(results) else 1)


if __name__ == '__main__':
    main()
