#!/usr/bin/python3
"""
Save Checkpoint with Validation Gate

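Creates a batch checkpoint ONLY after the batch passes its validation gates,
then records the results for the orchestrator:
1. V12 mechanical criteria (validate_batch.py) — hard gate
2. Dramatic quality criteria (quality_gate.py) — hard gate
3. Dramatic QC gate (dramatic_qc_gate.py) — voice, texture, behavioral DNA (soft gate: warns only)
   + Voice contamination check (baseline_comparison.py) — batches 3, 6, 9, 12 only (soft gate)
   + Full transition gate (transition_gate.py) — batch 12 only (hard gate)
4. Batch summary generation (generate_batch_summary.py) — for orchestrator
5. Orchestrator state update (update_orchestrator_state.py) — cross-batch tracking

This is the ONLY way to unlock the next batch - episodes must pass every hard
gate; soft-gate findings are reported but never block.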

Usage: python3 save_checkpoint.py <project_path> <batch_number> [--regen]
Example: python3 save_checkpoint.py ./singularity 1  (after completing eps 1-5)
         python3 save_checkpoint.py ./leviathan 1 --regen  (regenerating batch 1 only, keep eps 6-60)

Flow:
0. [Batch 1 only, unless --regen] If more than one batch of episodes already exists, back them up and clear the folder
1. Runs validate_batch.py on the batch (mechanical: word count, dialogue %, etc.)
2. If mechanical validation FAILS → checkpoint NOT saved, next batch BLOCKED
3. Runs quality_gate.py on the batch (dramatic: variety, beats, continuity)
4. If quality gate FAILS → checkpoint NOT saved, next batch BLOCKED
5. Runs dramatic_qc_gate.py on the batch (voice, texture, behavioral DNA) - SOFT GATE
6. If dramatic QC has issues → WARNING shown, but continues (soft gate)
7. Generates batch_summary.json for orchestrator tracking
8. Updates orchestrator_state.json (if exists) with batch data
9. If ALL PASS → checkpoint saved, next batch UNLOCKED

Re-run Safety:
When running the batch 1 checkpoint and more than GENERATION_BATCH_SIZE (default 5) episodes exist, the script will:
- Back up existing episodes to /state/backups/generation_[timestamp]/
- Clear the episodes folder
- Clear old checkpoints
This prevents losing previous work when starting a fresh generation.

The --regen flag SKIPS this fresh-start behavior. Use it when regenerating
a specific batch (e.g., batch 1 with improved voice prompts) while keeping
all other episodes in place.
"""

import filecmp
import json
import os
import re
import subprocess
import sys
from datetime import datetime
from pathlib import Path

# Add engine tools to path for imports.
# _SCRIPT_DIR is this script's directory; _ENGINE_TOOLS is two directories
# above it plus 'tools', and is put on sys.path only when it exists
# (presumably where engine_constants lives — TODO confirm).
_SCRIPT_DIR = Path(__file__).parent.resolve()
_ENGINE_TOOLS = _SCRIPT_DIR.parent.parent / 'tools'
if _ENGINE_TOOLS.exists():
    sys.path.insert(0, str(_ENGINE_TOOLS))
sys.path.insert(0, str(_SCRIPT_DIR.parent.parent.parent))  # CLAUDE_PROJECTS, for recoil.*
from recoil.core.paths import ProjectPaths

# Episodes per generation batch. Falls back to 5 when engine_constants is not
# importable; main() uses it for episode ranges and the fresh-start threshold.
try:
    from engine_constants import GENERATION_BATCH_SIZE
except ImportError:
    GENERATION_BATCH_SIZE = 5

def run_validation(project_path, batch_num):
    """Run the V12 mechanical validator (validate_batch.py) for one batch.

    The validator's report is streamed straight to the console (not captured).

    Args:
        project_path: Project root, forwarded to the validator as argv[1].
        batch_num: Batch number, forwarded to the validator as argv[2].

    Returns:
        True when the validator exits with status 0; False when it exits
        non-zero or when validate_batch.py is not next to this script.
    """
    validator = Path(__file__).parent / "validate_batch.py"

    if validator.exists():
        completed = subprocess.run(
            [sys.executable, str(validator), str(project_path), str(batch_num)],
            capture_output=False,  # show the validator's output live
        )
        return completed.returncode == 0

    print(f"ERROR: Validation script not found: {validator}")
    return False


def run_quality_gate(project_path, batch_num):
    """Run the dramatic quality gate (quality_gate.py) for one batch.

    Output is not captured, so the gate prints directly to the console.

    Args:
        project_path: Project root, forwarded to the gate as argv[1].
        batch_num: Batch number, forwarded to the gate as argv[2].

    Returns:
        True only when quality_gate.py exists next to this script and exits 0.
    """
    gate_path = Path(__file__).parent.joinpath("quality_gate.py")
    if not gate_path.exists():
        print(f"ERROR: Quality gate script not found: {gate_path}")
        return False

    command = [sys.executable, str(gate_path), str(project_path), str(batch_num)]
    exit_code = subprocess.run(command, capture_output=False).returncode
    return exit_code == 0


def run_transition_gate(project_path):
    """Run transition_gate.py --check for full-series validation.

    Args:
        project_path: Project root, forwarded to the gate as argv[1].

    Returns:
        (passed, needs_review):
        - (True, False) on exit code 0, or when the gate script is missing
          (the check is skipped);
        - (True, True) on exit code 2 (passed, but transitions need review);
        - (False, False) on any other exit code (hard fail).
    """
    gate = Path(__file__).parent / "transition_gate.py"
    if not gate.exists():
        print(f"WARNING: Transition gate script not found: {gate}")
        return True, False  # optional tool: treat as passed

    outcome = subprocess.run(
        [sys.executable, str(gate), str(project_path), "--check"],
        capture_output=False,
    )
    # Exit code -> (passed, needs_review); unknown codes are hard failures.
    verdicts = {0: (True, False), 2: (True, True)}
    return verdicts.get(outcome.returncode, (False, False))


def _extract_issue_count(output):
    """Best-effort parse of an issue count from dramatic_qc_gate.py stdout.

    Lines containing "issues found" are treated as the authoritative summary.
    Only if none of them carries a number does the parse fall back to any other
    line mentioning "issue". Within each tier the last line with a number wins,
    and the first integer on that line is used. Returns 0 if nothing parses.
    """
    summary_count = None
    fallback_count = None
    for line in (output or "").splitlines():
        lowered = line.lower()
        if "issue" not in lowered:
            continue
        match = re.search(r"\d+", line)
        if match is None:
            continue
        count = int(match.group())
        if "issues found" in lowered:
            summary_count = count
        else:
            fallback_count = count
    if summary_count is not None:
        return summary_count
    if fallback_count is not None:
        return fallback_count
    return 0


def run_dramatic_qc_gate(project_path, batch_num):
    """
    Run dramatic_qc_gate.py for voice, texture, behavioral DNA checks.
    This is a SOFT GATE - issues are reported but don't block.

    Args:
        project_path: Project root, forwarded to the gate script.
        batch_num: Batch number, forwarded as ``--batch <n>``.

    Returns:
        (has_issues, issue_count). (False, 0) when the script exits 0 or is
        not found next to this file; otherwise (True, n) where n >= 1 is parsed
        best-effort from the script's stdout.
    """
    dramatic_qc_script = Path(__file__).parent / "dramatic_qc_gate.py"

    if not dramatic_qc_script.exists():
        print(f"NOTE: Dramatic QC gate not found: {dramatic_qc_script}")
        print("      Skipping dramatic quality checks.")
        return False, 0  # Skip if not available

    # Output is captured (so the issue count can be parsed) and echoed below.
    result = subprocess.run(
        [sys.executable, str(dramatic_qc_script), str(project_path), "--batch", str(batch_num)],
        capture_output=True,
        text=True,
    )

    # Print output regardless of exit code; diagnostics stay on stderr,
    # matching the other captured tool runners in this file.
    if result.stdout:
        print(result.stdout)
    if result.stderr:
        print(result.stderr, file=sys.stderr)

    # Exit code 0 = no issues; any non-zero code means issues (soft gate).
    if result.returncode == 0:
        return False, 0
    # A failing run always reports at least one issue, even if nothing parses.
    return True, max(_extract_issue_count(result.stdout), 1)

def run_baseline_comparison(project_path, batch_num):
    """
    Run baseline_comparison.py at batches 3, 6, 9, 12 (soft gate).

    Purpose: Detect unintentional voice issues:
    - Contamination (Character A using Character B's patterns)
    - Generic drift (characters becoming too similar)
    - Regression (sudden voice changes from previous batch)

    Args:
        project_path: Project root, forwarded to the comparison script.
        batch_num: Batch number; the script runs only for 3, 6, 9 and 12.

    Returns:
        (has_issues, issue_details). (False, None) for other batches, when the
        script is missing, or when it exits 0; otherwise (True, captured stdout).
    """
    if batch_num not in (3, 6, 9, 12):
        return False, None

    baseline_script = Path(__file__).parent / "baseline_comparison.py"

    if not baseline_script.exists():
        print(f"NOTE: Voice contamination script not found: {baseline_script}")
        print("      Skipping cumulative drift check.")
        return False, None

    print(f"\n--- Running Voice Contamination Check (Batch {batch_num}) ---")

    result = subprocess.run(
        [sys.executable, str(baseline_script), str(project_path), str(batch_num)],
        capture_output=True,
        text=True,
    )

    # Echo captured output; diagnostics go to stderr like the other
    # captured tool runners in this file.
    if result.stdout:
        print(result.stdout)
    if result.stderr:
        print(result.stderr, file=sys.stderr)

    # Exit code 0 = no drift; any other code is reported as drift.
    if result.returncode == 0:
        return False, None
    return True, result.stdout


def run_batch_summary_generator(project_path, batch_num):
    """
    Produce batch_summary.json by running generate_batch_summary.py.

    The orchestrator uses the summary for cross-series verification. The
    generator is looked up in <project_path>/../../recoil/tools/; when it is
    absent the step is treated as optional and reported as a success.

    Args:
        project_path: pathlib.Path of the project root.
        batch_num: Batch number forwarded to the generator.

    Returns:
        True if the generator is missing or exits 0; False on a non-zero exit.
    """
    tools_dir = project_path.parent.parent.joinpath("recoil", "tools")
    generator = tools_dir / "generate_batch_summary.py"

    if generator.exists():
        print("\n--- Generating Batch Summary ---")
        proc = subprocess.run(
            [sys.executable, str(generator), str(project_path), str(batch_num)],
            capture_output=True,
            text=True,
        )
        # Echo captured output; diagnostics stay on stderr.
        if proc.stdout:
            print(proc.stdout)
        if proc.stderr:
            print(proc.stderr, file=sys.stderr)
        return proc.returncode == 0

    # Optional step: a missing generator must not fail the checkpoint.
    print(f"NOTE: Batch summary generator not found: {generator}")
    print("      Skipping batch summary generation (optional for non-orchestrated mode).")
    return True


def run_orchestrator_state_update(project_path, batch_num):
    """
    Run update_orchestrator_state.py to update orchestrator tracking.

    Only runs if orchestrator_state.json exists in the project's state
    directory (i.e., in orchestrated mode).

    Args:
        project_path: pathlib.Path of the project root.
        batch_num: Batch number forwarded to the updater.

    Returns:
        True when not in orchestrated mode, when the updater script is not
        found, or when it exits 0; False when it exits non-zero. The step is
        supplementary, so a False result is meant to be reported as a
        warning rather than block the checkpoint.
    """
    orchestrator_state = ProjectPaths.from_root(project_path).state_dir / "orchestrator_state.json"
    if not orchestrator_state.exists():
        # Not in orchestrated mode - skip silently
        return True

    # Look for the script in tools
    update_script = project_path.parent.parent / "recoil" / "tools" / "update_orchestrator_state.py"

    if not update_script.exists():
        print(f"NOTE: Orchestrator state updater not found: {update_script}")
        return True  # Don't fail - will work without it

    print("\n--- Updating Orchestrator State ---")

    result = subprocess.run(
        [sys.executable, str(update_script), str(project_path), str(batch_num)],
        capture_output=True,
        text=True,
    )

    if result.stdout:
        print(result.stdout)
    if result.stderr:
        print(result.stderr, file=sys.stderr)

    if result.returncode != 0:
        print("WARNING: Orchestrator state update returned non-zero exit code")
        # Report the failure; the caller decides not to block on it.
        return False

    return True


def _clear_checkpoints(project_path):
    """Delete every batch_*.json file in the project's checkpoints directory.

    Returns True if the checkpoints directory existed (and was cleared).
    """
    checkpoints_dir = ProjectPaths.from_root(project_path).checkpoints_dir
    if not checkpoints_dir.exists():
        return False
    for cp_file in checkpoints_dir.glob("batch_*.json"):
        cp_file.unlink()
    return True


def _is_backed_up(episode_files, backup_folder):
    """Return True if every file in episode_files has a byte-identical copy
    with the same name in backup_folder, so deleting the originals loses nothing."""
    for ep_file in episode_files:
        backup_copy = backup_folder / ep_file.name
        if not backup_copy.is_file() or not filecmp.cmp(ep_file, backup_copy, shallow=False):
            return False
    return True


def backup_previous_generation(project_path):
    """
    If episodes exist, back them up before starting a fresh generation.
    Called only when running batch 1 checkpoint.

    Every ep_*.md in <project>/episodes is moved into a new timestamped
    generation_* folder under the project's backups directory, and old
    batch_*.json checkpoints are deleted. If the newest existing backup
    already holds an identical copy of every current episode, no new backup
    is made; the episodes and checkpoints are simply cleared.

    NOTE(review): this moves ALL ep_*.md files, including any freshly
    generated batch-1 episodes, before batch 1 is validated — confirm the
    calling workflow expects that.

    Returns True if backup was made, False if no backup needed.
    """
    episodes_dir = project_path / "episodes"
    if not episodes_dir.exists():
        return False

    # Find existing episode files
    existing_episodes = list(episodes_dir.glob("ep_*.md"))
    if not existing_episodes:
        return False

    backups_dir = ProjectPaths.from_root(project_path).backups_dir

    # Skip a duplicate backup only if the newest backup already contains an
    # identical copy of every current episode. Comparing the file COUNT alone
    # is unsafe: a regenerated series with the same number of episodes would
    # be deleted without ever being backed up.
    if backups_dir.exists():
        existing_backups = sorted(backups_dir.glob("generation_*"), reverse=True)
        if existing_backups and _is_backed_up(existing_episodes, existing_backups[0]):
            latest_backup = existing_backups[0]
            print(f"  Episodes already backed up in {latest_backup.name}")
            # Still clear the episodes folder and old checkpoints
            for ep_file in existing_episodes:
                ep_file.unlink()
            if _clear_checkpoints(project_path):
                print("  Cleared old checkpoints")
            return False

    # Create backup folder with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_folder = backups_dir / f"generation_{timestamp}"
    backup_folder.mkdir(parents=True, exist_ok=True)

    print(f"  Backing up {len(existing_episodes)} episodes to {backup_folder.name}/")

    # Move episodes to backup
    for ep_file in existing_episodes:
        ep_file.rename(backup_folder / ep_file.name)

    # Clear old checkpoints (since we're starting over)
    if _clear_checkpoints(project_path):
        print("  Cleared old checkpoints")

    return True


def main():
    """Run every gate for one batch and, if the hard gates pass, save a checkpoint.

    Usage: python3 save_checkpoint.py <project_path> <batch_number> [--regen]

    Exits with status 1 on bad arguments, a missing project path, a failed hard
    gate (V12 validation, quality gate, or the batch-12 transition gate), or a
    corrupted state file. Soft gates (dramatic QC, voice contamination) only
    print warnings. If writing the checkpoint or state file raises, the state
    file is restored from the snapshot taken before the orchestrator update,
    any checkpoint file this run started writing is removed, and the error is
    re-raised.
    """
    # Positional args exclude "--" flags, so --regen may appear in any position.
    positional = [arg for arg in sys.argv[1:] if not arg.startswith("--")]
    if len(positional) < 2:
        print("Usage: python3 save_checkpoint.py <project_path> <batch_number> [--regen]")
        print("Example: python3 save_checkpoint.py ./singularity 1")
        print("         python3 save_checkpoint.py ./leviathan 1 --regen")
        sys.exit(1)

    project_path = Path(positional[0]).resolve()
    try:
        batch_num = int(positional[1])
    except ValueError:
        print(f"Error: Batch number must be an integer, got: {positional[1]!r}")
        sys.exit(1)
    if batch_num < 1:
        print(f"Error: Batch number must be 1 or greater, got: {batch_num}")
        sys.exit(1)
    regen = "--regen" in sys.argv[1:]

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}")
        sys.exit(1)

    # STEP 0: If batch 1 and episodes exist, backup and clear first
    # --regen flag skips this: regenerating batch 1 while keeping eps 6-60
    if batch_num == 1 and not regen:
        episodes_dir = project_path / "episodes"
        existing = list(episodes_dir.glob("ep_*.md")) if episodes_dir.exists() else []
        if len(existing) > GENERATION_BATCH_SIZE:  # More than just batch 1 exists
            print(f"\n{'#'*60}")
            print(f"# FRESH START DETECTED: Batch 1 with existing episodes")
            print(f"{'#'*60}")
            print(f"\nStep 0: Backing up previous generation...")
            backup_previous_generation(project_path)
    elif batch_num == 1 and regen:
        print(f"\n  --regen flag: skipping fresh-start backup (keeping existing episodes)")

    # STEP 1: Run mechanical validation (V12 constraints)
    print(f"\n{'#'*60}")
    print(f"# CHECKPOINT GATE: Batch {batch_num}")
    print(f"{'#'*60}")
    print(f"\nStep 1: Validating V12 mechanical constraints...")

    if not run_validation(project_path, batch_num):
        print(f"\n{'!'*60}")
        print(f"! CHECKPOINT BLOCKED: Batch {batch_num} failed V12 validation")
        print(f"! Fix the failing episodes and try again")
        print(f"{'!'*60}\n")
        sys.exit(1)

    # STEP 2: Run quality gate (dramatic checks)
    print(f"\nStep 2: Running quality gate (dramatic checks)...")

    if not run_quality_gate(project_path, batch_num):
        print(f"\n{'!'*60}")
        print(f"! CHECKPOINT BLOCKED: Batch {batch_num} failed quality gate")
        print(f"! Regenerate the failing episodes and try again")
        print(f"{'!'*60}\n")
        sys.exit(1)

    # STEP 3: Run DRAMATIC QC GATE (soft gate - warns but doesn't block)
    print(f"\nStep 3: Running dramatic QC gate (voice, texture, behavioral DNA)...")

    has_qc_issues, qc_issue_count = run_dramatic_qc_gate(project_path, batch_num)

    if has_qc_issues:
        print(f"\n{'*'*60}")
        print(f"* DRAMATIC QC: {qc_issue_count} issue(s) flagged (soft gate)")
        print(f"* These are quality suggestions, not blocking errors.")
        print(f"* Consider running /rewrite to address flagged issues.")
        print(f"{'*'*60}")

    # STEP 3.5: Voice contamination check (batches 3, 6, 9, 12)
    has_voice_issues, _voice_details = run_baseline_comparison(project_path, batch_num)
    if has_voice_issues:
        print(f"\n{'*'*60}")
        print(f"* VOICE CONTAMINATION DETECTED")
        print(f"* Run qualitative review before continuing:")
        print(f"* /dramatic-qc {project_path.name} --mode post --batch {batch_num} --lens voice")
        print(f"{'*'*60}")
        # Note: Soft gate - doesn't block, but strongly recommends review

    # STEP 4: Run FULL TRANSITION GATE (only after batch 12)
    needs_review = False
    if batch_num == 12:
        print(f"\nStep 4: Running FULL TRANSITION GATE (final validation)...")
        print(f"        Validating all 59 transitions across 60 episodes...")

        transition_passed, needs_review = run_transition_gate(project_path)

        if not transition_passed:
            print(f"\n{'!'*60}")
            print(f"! CHECKPOINT BLOCKED: Full transition gate FAILED")
            print(f"! Fix the transition issues before completing the series")
            print(f"! Run: python3 .claude/hooks/transition_gate.py {project_path} --fix-list")
            print(f"{'!'*60}\n")
            sys.exit(1)

        if needs_review:
            print(f"\n{'*'*60}")
            print(f"* TRANSITION GATE: Passed with reviews needed")
            print(f"* AI must verify flagged transitions are intentional")
            print(f"{'*'*60}")

    # STEP 4.5: Generate batch summary (for orchestrator tracking)
    summary_ok = run_batch_summary_generator(project_path, batch_num)
    if not summary_ok:
        print("WARNING: Batch summary generation failed — orchestrator state may be incomplete")

    # STEP 4.6: Update orchestrator state (if in orchestrated mode)
    # Snapshot pre-update state for rollback if checkpoint save fails later
    state_file = ProjectPaths.from_root(project_path).state_dir / "current_state.json"
    _pre_checkpoint_state = None
    if state_file.exists():
        with open(state_file, 'r') as f:
            _pre_checkpoint_state = f.read()

    orch_ok = run_orchestrator_state_update(project_path, batch_num)
    if not orch_ok:
        print("WARNING: Orchestrator state update failed — state may be out of sync")

    # STEP 5: All gates passed - save checkpoint
    step_num = 5 if batch_num == 12 else 4
    print(f"\nStep {step_num}: Saving checkpoint...")

    # Set just before this run opens (and truncates) a checkpoint file, so the
    # rollback below knows whether there is a checkpoint file to discard.
    checkpoint_file = None
    try:
        checkpoint_dir = ProjectPaths.from_root(project_path).checkpoints_dir
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Read current state
        state = {}
        if state_file.exists():
            with open(state_file, 'r') as f:
                try:
                    state = json.load(f)
                except json.JSONDecodeError as e:
                    print(f"ERROR: Corrupted state file: {e}")
                    sys.exit(1)

        # Create checkpoint
        ep_start = (batch_num - 1) * GENERATION_BATCH_SIZE + 1
        ep_end = batch_num * GENERATION_BATCH_SIZE

        checkpoint = {
            "batch": batch_num,
            "episodes": f"{ep_start}-{ep_end}",
            "timestamp": datetime.now().isoformat(),
            "project": project_path.name,
            "validated": True,
            "dramatic_qc_issues": qc_issue_count if has_qc_issues else 0,
            "state_snapshot": state
        }

        checkpoint_file = checkpoint_dir / f"batch_{batch_num:02d}_checkpoint.json"
        with open(checkpoint_file, 'w') as f:
            json.dump(checkpoint, f, indent=2)

        # STEP 6: Update state file
        step_num = 6 if batch_num == 12 else 5
        print(f"Step {step_num}: Updating state file...")

        state["last_completed_batch"] = batch_num
        state["last_completed_episode"] = ep_end
        state["next_batch"] = batch_num + 1
        state.setdefault("generation", {})
        state["generation"]["last_validated"] = ep_end
        state["generation"]["last_checkpoint"] = datetime.now().isoformat()
        state["generation"]["validation_passed"] = True

        with open(state_file, 'w') as f:
            json.dump(state, f, indent=2)
    except Exception as e:
        # Rollback: restore pre-checkpoint state if save failed
        print(f"\nERROR: Checkpoint save failed: {e}")
        if _pre_checkpoint_state is not None:
            print("  Rolling back state file to pre-checkpoint snapshot...")
            with open(state_file, 'w') as f:
                f.write(_pre_checkpoint_state)
        # A saved checkpoint marks the batch as validated; don't leave one
        # (possibly half-written) behind when the matching state update failed.
        if checkpoint_file is not None and checkpoint_file.exists():
            print("  Removing checkpoint file written by this run...")
            checkpoint_file.unlink()
        raise

    # Success message
    print(f"\n{'='*60}")
    print(f"CHECKPOINT SAVED: Batch {batch_num} (Episodes {ep_start}-{ep_end})")
    print(f"Validation: PASSED")

    if batch_num == 12:
        # Series complete!
        print(f"\n{'*'*60}")
        print(f"*  SERIES GENERATION COMPLETE!")
        print(f"*  All 60 episodes validated and saved.")
        print(f"*  All transitions verified.")
        print(f"{'*'*60}")
        print(f"\nNext steps:")
        print(f"  1. Review any flagged transitions (if any)")
        print(f"  2. Compile the series: /compile {project_path.name}")
        print(f"  3. Run final quality review")
    else:
        print(f"Next batch: {batch_num + 1} is now UNLOCKED")
        print(f"{'='*60}\n")
        print(f"To continue generation:")
        print(f"  1. Read last 2 episodes for continuity")
        print(f"  2. Generate batch {batch_num + 1} (episodes {ep_end + 1}-{ep_end + GENERATION_BATCH_SIZE})")
        print(f"  3. Run: python3 save_checkpoint.py {project_path.name} {batch_num + 1}")

    print(f"{'='*60}\n")

# Run the checkpoint gate only when executed as a script, not on import.
if __name__ == "__main__":
    main()
