#!/usr/bin/python3
"""
Initialize Orchestrator State

Creates the orchestrator_state.json file for a project by extracting:
1. Thread tracker from treatment.md THREAD INDEX
2. Emotional beat schedule from the per-format schedules hard-coded in this module
3. Format-specific state (pattern_state, resonance, rhythm, etc.)
4. Position tracking

Format-aware: reads project_config.json to determine format, defaults to kill_box.

Usage: python3 init_orchestrator_state.py <project_path>
Example: python3 init_orchestrator_state.py ./leviathan

Output: ./[project]/state/orchestrator_state.json
"""


import json
import re
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any

# Add engine tools to path for imports
_SCRIPT_DIR = Path(__file__).parent.resolve()
_ENGINE_ROOT = _SCRIPT_DIR.parent
sys.path.insert(0, str(_SCRIPT_DIR))
sys.path.insert(0, str(_SCRIPT_DIR.parent.parent))  # CLAUDE_PROJECTS, for recoil.*
from recoil.core.paths import ProjectPaths

# Prefer the shared engine_constants module (presumably resolved through the
# sys.path entries inserted above). If it cannot be imported, fall back to
# local defaults so this script still runs on its own.
try:
    from engine_constants import GENERATION_BATCH_SIZE, load_format_constants
except ImportError:
    # Fallback number of episodes per generation batch.
    GENERATION_BATCH_SIZE = 5

    def load_format_constants(format_name: str) -> dict:
        """Fallback loader: return no per-format overrides, whatever the format."""
        return {}

# ── Kill Box (V12) emotional beat schedule ──────────────────────────────────
# Hardcoded as the backward-compatible default when no format is specified.
EMOTIONAL_BEAT_SCHEDULE = {
    "FIRST_CRACK": {"target_episode": 10, "tolerance": 2},
    "THRESHOLD": {"target_episode": 15, "tolerance": 2},
    "DEEPENING": {"target_episode": 20, "tolerance": 2},
    "VULNERABILITY": {"target_episode": 26, "tolerance": 2},
    "MIDPOINT": {"target_episode": 30, "tolerance": 2},
    "FRACTURE": {"target_episode": 33, "tolerance": 2},
    "BETRAYAL_DOUBT": {"target_episode": 36, "tolerance": 2},
    "COST": {"target_episode": 42, "tolerance": 2},
    "ALL_IS_LOST": {"target_episode": 45, "tolerance": 2},
    "RECONCILIATION": {"target_episode": 50, "tolerance": 2},
    "RESOLUTION": {"target_episode": 60, "tolerance": 2}
}

# ── Micro-format emotional beat schedules ───────────────────────────────────
# Used by both puzzle_box and kill_box_micro (6 required beats, 16-episode arc).
MICRO_EMOTIONAL_BEAT_SCHEDULE = {
    "THRESHOLD": {"target": 3},
    "MIDPOINT": {"target": 8},
    "VULNERABILITY": {"target": 11},
    "ALL_IS_LOST": {"target": 12},
    "FRACTURE_RECONCILIATION": {"target": 15},
    "RESOLUTION": {"target": 16},
}

# ── Goal-backward checkpoints per format ────────────────────────────────────
GOAL_BACKWARD_KILL_BOX = [3, 6, 9, 12]        # batch numbers
GOAL_BACKWARD_MICRO = [4, 8, 12, 16]           # episode numbers


def extract_threads_from_treatment(treatment_path: Path) -> dict:
    """
    Extract thread information from treatment.md THREAD INDEX section.

    Expected format in treatment.md:
    ## THREAD INDEX
    | Thread | Plant | Advance | Payoff |
    |--------|-------|---------|--------|
    | THREAD_NAME | Ep X | Ep Y, Z | Ep N |

    The table is parsed one line at a time, so a row with extra columns or an
    empty cell cannot bleed into the next row. Only the first four cells of
    each row are used. Header rows and separator rows (including alignment
    markers such as ``:---:``) are skipped.

    Args:
        treatment_path: Path to the project's treatment.md.

    Returns:
        Dict keyed by thread id (thread name upper-cased, spaces replaced by
        underscores). Each value is the initial tracker entry for that thread.
        Empty when the file or the THREAD INDEX section is missing (a WARNING
        is printed in both cases).
    """
    threads: dict = {}

    if not treatment_path.exists():
        print(f"WARNING: treatment.md not found at {treatment_path}")
        return threads

    # Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
    # Windows) can fail on dashes and quotes. Undecodable bytes are replaced
    # rather than aborting, because this parser is best-effort.
    content = treatment_path.read_text(encoding="utf-8", errors="replace")

    # Find THREAD INDEX section (everything up to the next "##" heading or EOF)
    thread_section_match = re.search(
        r'##\s*THREAD\s*INDEX.*?\n(.*?)(?=\n##|\Z)',
        content,
        re.IGNORECASE | re.DOTALL
    )

    if not thread_section_match:
        print("WARNING: No THREAD INDEX section found in treatment.md")
        return threads

    separator_chars = set('-: ')

    for raw_line in thread_section_match.group(1).splitlines():
        line = raw_line.strip()
        if not line.startswith('|'):
            continue  # not a table row

        # Drop exactly one leading and one trailing pipe, then split cells.
        inner = line[1:]
        if inner.endswith('|'):
            inner = inner[:-1]
        cells = [cell.strip() for cell in inner.split('|')]
        if len(cells) < 4:
            continue  # malformed row: not enough columns

        thread_name, plant_info, advance_info, payoff_info = cells[:4]

        # Skip header rows and empty names
        if not thread_name or thread_name.lower() == 'thread':
            continue
        # Skip separator rows such as "-----", ":---", ":---:" or "---:"
        if set(thread_name) <= separator_chars:
            continue

        # Parse episode numbers
        plant_ep = extract_episode_number(plant_info)
        advance_eps = extract_episode_numbers(advance_info)
        payoff_ep = extract_episode_number(payoff_info)

        thread_id = thread_name.upper().replace(' ', '_')
        threads[thread_id] = {
            "name": thread_name,
            "status": "pending",
            "planted_episode": None,
            "advanced_episodes": [],
            "payoff_episode": None,
            "target_plant": plant_ep,
            "target_advances": advance_eps,
            # A missing payoff defaults to episode 60.
            "target_payoff": payoff_ep if payoff_ep else 60,
            "notes": ""
        }

    return threads


def extract_episode_number(text: str) -> Optional[int]:
    """
    Extract single episode number from text like 'Ep 10' or 'Episode 15'.

    A number that directly follows an "Ep"/"Ep."/"Episode"/"Episodes" label is
    preferred, so 'Act 2, Ep 14' yields 14 rather than 2. If no labelled
    number is present, the first bare number in the text is used.

    Args:
        text: Free-form cell text; may be empty or None.

    Returns:
        The episode number, or None when the text contains no digits.
    """
    if not text:
        return None

    # \b keeps words that merely contain "ep" (e.g. "Step 3") from counting
    # as an episode label.
    labelled = re.search(r'\bEp(?:isode)?s?\.?\s*(\d+)', text, re.IGNORECASE)
    if labelled:
        return int(labelled.group(1))

    bare = re.search(r'\d+', text)
    return int(bare.group()) if bare else None


def extract_episode_numbers(text: str) -> List[int]:
    """
    Extract multiple episode numbers from text like 'Ep 10, 20, 30' or 'Episodes 15-20'.

    Single numbers and ranges may be mixed freely ('Ep 5, 10-12, 20' gives
    [5, 10, 11, 12, 20]). Ranges are inclusive, may be written with a hyphen,
    en dash or em dash, and are normalised when given in reverse ('20-18').
    Numbers are returned in the order they appear; duplicates are kept.

    Args:
        text: Free-form cell text; may be empty or None.

    Returns:
        List of episode numbers, empty when the text contains no digits.
    """
    if not text:
        return []

    episodes: List[int] = []
    # Each token is either "N" or "N<dash>M"; the second group is empty for
    # standalone numbers.
    for first, last in re.findall(r'(\d+)(?:\s*[-\u2013\u2014]\s*(\d+))?', text):
        start = int(first)
        if not last:
            episodes.append(start)
            continue
        end = int(last)
        if end < start:
            start, end = end, start
        episodes.extend(range(start, end + 1))
    return episodes


def create_initial_emotional_beat_map() -> dict:
    """
    Build the starting emotional beat map for kill_box (V12).

    Returns:
        One entry per beat in EMOTIONAL_BEAT_SCHEDULE, carrying the scheduled
        target episode and tolerance, marked "pending" with no actual episode.
    """
    return {
        name: {
            "beat_name": name,
            "target_episode": spec["target_episode"],
            "tolerance": spec["tolerance"],
            "status": "pending",
            "actual_episode": None,
        }
        for name, spec in EMOTIONAL_BEAT_SCHEDULE.items()
    }


def create_initial_pattern_state() -> dict:
    """
    Build the zeroed hook / cliffhanger pattern tracker for kill_box (V12).

    Returns:
        Dict with "hooks" and "cliffhangers" sections. Each holds four
        counters set to 0 and its own empty "history" list.
    """
    # Counters first, history last, so key order matches the serialized layout.
    hooks = dict.fromkeys(
        ("silent_count", "dialogue_count",
         "consecutive_silent", "consecutive_dialogue"),
        0,
    )
    hooks["history"] = []

    cliffhangers = dict.fromkeys(
        ("mid_action_count", "aftermath_count",
         "consecutive_mid_action", "consecutive_aftermath"),
        0,
    )
    cliffhangers["history"] = []

    return {"hooks": hooks, "cliffhangers": cliffhangers}


# ── Format detection ────────────────────────────────────────────────────────

def detect_format(project_path: Path) -> str:
    """
    Read format name from project_config.json.

    Args:
        project_path: Project root directory expected to contain
            project_config.json.

    Returns:
        The stripped "format" string from the config. Returns 'kill_box' if
        the file is missing, unreadable, not valid JSON, not a JSON object,
        or has no usable (non-empty string) format field. A WARNING is
        printed whenever an existing file cannot be used as-is.
    """
    default_format = "kill_box"
    config_path = project_path / "project_config.json"
    if not config_path.exists():
        return default_format

    try:
        # utf-8-sig reads plain UTF-8 and also tolerates a leading BOM.
        with open(config_path, encoding="utf-8-sig") as f:
            config = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as exc:
        print(f"WARNING: Could not read {config_path}: {exc}")
        return default_format

    # A valid JSON document can still be a list, string, number, etc.
    if not isinstance(config, dict):
        print(f"WARNING: {config_path} is not a JSON object, using {default_format}")
        return default_format

    fmt = config.get("format")
    if isinstance(fmt, str) and fmt.strip():
        return fmt.strip()
    if fmt:
        # Truthy but not a string (e.g. a number or list): refuse to pass it on.
        print(f"WARNING: Ignoring non-string format {fmt!r} in {config_path}")
    return default_format


# ── Format-specific state builders ──────────────────────────────────────────

def _micro_emotional_beat_map() -> dict:
    """
    Build the starting beat map for the 16-episode formats (puzzle_box, kill_box_micro).

    Returns:
        One entry per beat in MICRO_EMOTIONAL_BEAT_SCHEDULE with its scheduled
        "target" and an unset "actual".
    """
    return {
        name: {"target": spec["target"], "actual": None}
        for name, spec in MICRO_EMOTIONAL_BEAT_SCHEDULE.items()
    }


def build_format_state_kill_box() -> dict:
    """
    Build format_state for kill_box (V12). Matches existing top-level structure.

    Returns:
        Dict with the initial emotional beat map, the zeroed pattern state and
        the goal-backward checkpoint list (batch numbers).
    """
    return {
        "emotional_beat_map": create_initial_emotional_beat_map(),
        "pattern_state": create_initial_pattern_state(),
        # Copy so that mutating the returned state can never alter the
        # module-level GOAL_BACKWARD_KILL_BOX constant.
        "goal_backward_checkpoints": list(GOAL_BACKWARD_KILL_BOX),
    }


def build_format_state_puzzle_box() -> dict:
    """
    Build format_state for puzzle_box.

    Returns:
        Dict with the micro emotional beat map, zeroed resonance and rhythm
        counters, eruption and voice-over trackers (budgets of 3 and 4, none
        used yet), and the goal-backward checkpoint list (episode numbers).
    """
    return {
        "emotional_beat_map": _micro_emotional_beat_map(),
        "resonance_state": {
            "rhyme": 0,
            "withhold": 0,
            "dissonance": 0,
            "object": 0,
            "absence": 0,
        },
        "eruption_tracker": {
            "budget": 3,
            "used": 0,
            "episodes": [],
        },
        "rhythm_distribution": {
            "suspended": 0,
            "layered": 0,
            "kinetic": 0,
            "drift": 0,
        },
        "vo_tracker": {
            "budget": 4,
            "used": 0,
            "episodes": [],
        },
        # Copy so that mutating the returned state can never alter the
        # module-level GOAL_BACKWARD_MICRO constant.
        "goal_backward_checkpoints": list(GOAL_BACKWARD_MICRO),
    }


def build_format_state_kill_box_micro() -> dict:
    """
    Build format_state for kill_box_micro.

    Returns:
        Dict with the micro emotional beat map, zeroed cliffhanger and rhythm
        counters under "pattern_state", and the goal-backward checkpoint list
        (episode numbers).
    """
    return {
        "emotional_beat_map": _micro_emotional_beat_map(),
        "pattern_state": {
            "cliffhangers": {
                "reveal": 0,
                "reversal": 0,
                "clock": 0,
                "dilemma": 0,
            },
            "rhythm": {
                "frenetic": 0,
                "measured": 0,
                "fluid": 0,
            },
        },
        # Copy so that mutating the returned state can never alter the
        # module-level GOAL_BACKWARD_MICRO constant.
        "goal_backward_checkpoints": list(GOAL_BACKWARD_MICRO),
    }


# Dispatch table — add new formats here.
# Maps each supported format name to a zero-argument builder that returns the
# format_state dict for that format. build_format_state() looks names up here
# and falls back to the kill_box builder for any name that is not listed.
_FORMAT_STATE_BUILDERS: Dict[str, Any] = {
    "kill_box": build_format_state_kill_box,
    "puzzle_box": build_format_state_puzzle_box,
    "kill_box_micro": build_format_state_kill_box_micro,
}


def build_format_state(format_name: str) -> dict:
    """
    Produce the format_state section for *format_name*.

    Looks the name up in _FORMAT_STATE_BUILDERS. An unregistered name prints
    a WARNING and uses the kill_box builder instead.
    """
    try:
        make_state = _FORMAT_STATE_BUILDERS[format_name]
    except KeyError:
        print(f"WARNING: Unknown format '{format_name}', falling back to kill_box")
        make_state = build_format_state_kill_box
    return make_state()


def _resolve_batch_size(format_constants: dict) -> int:
    """Return a usable batch size, preferring the per-format override."""
    candidate = (format_constants or {}).get(
        "GENERATION_BATCH_SIZE", GENERATION_BATCH_SIZE
    )
    # bool is an int subclass; reject it along with zero/negative/non-int
    # values so the batch arithmetic below can never divide by zero.
    if isinstance(candidate, int) and not isinstance(candidate, bool) and candidate > 0:
        return candidate
    print(f"WARNING: Invalid GENERATION_BATCH_SIZE {candidate!r}, using {GENERATION_BATCH_SIZE}")
    return GENERATION_BATCH_SIZE


def _existing_episode_numbers(episodes_dir: Path) -> List[int]:
    """Return the episode numbers of all ep_<N>*.md files in *episodes_dir*."""
    if not episodes_dir.is_dir():
        return []
    numbers = []
    for episode_file in episodes_dir.glob("ep_*.md"):
        # Files such as ep_notes.md carry no number and are ignored.
        number_match = re.match(r'ep_(\d+)', episode_file.stem)
        if number_match:
            numbers.append(int(number_match.group(1)))
    return numbers


def _checkpointed_batches(checkpoint_dir: Path) -> set:
    """Return the batch numbers named by files in *checkpoint_dir*."""
    if not checkpoint_dir.is_dir():
        return set()
    batches = set()
    for entry in checkpoint_dir.iterdir():
        # Compare whole numbers, not substrings, so "batch_100" does not count
        # as a checkpoint for batch 10. Two or more digits are required,
        # matching the zero-padded "batch_NN" naming this script looks for.
        for digits in re.findall(r'batch_(\d{2,})', entry.name):
            batches.add(int(digits))
    return batches


def create_orchestrator_state(project_path: Path) -> dict:
    """
    Create the full orchestrator state structure.

    Steps:
      1. Parse threads from <project>/treatment.md.
      2. Detect the project format and build its format-specific state.
      3. If episode files already exist, advance the position to the last
         batch that has a checkpoint, counting contiguously from batch 1.

    Args:
        project_path: Project root directory.

    Returns:
        The state dict, ready to be serialized as orchestrator_state.json.
    """
    treatment_path = project_path / "treatment.md"

    # Extract threads from treatment
    threads = extract_threads_from_treatment(treatment_path)

    # ── Detect format ───────────────────────────────────────────────────
    format_name = detect_format(project_path)
    format_constants = load_format_constants(format_name)
    batch_size = _resolve_batch_size(format_constants)

    # ── Build format-specific state ─────────────────────────────────────
    format_state = build_format_state(format_name)

    # One timestamp so "created" and "last_updated" are identical at init.
    timestamp = datetime.now().isoformat()

    state = {
        "meta": {
            "project": project_path.name,
            "created": timestamp,
            "last_updated": timestamp,
            "engine_version": "V12"
        },
        "format": format_name,
        "thread_tracker": threads,
        "format_state": format_state,
        "position": {
            "last_completed_batch": 0,
            "last_completed_episode": 0,
            "next_batch": 1
        },
        "cross_batch_flags": {
            "voice_concerns": [],
            "continuity_breaks": [],
            "overdue_threads": [],
            "pattern_violations": []
        },
        "batch_summaries": {},
    }

    # Detect existing episodes and adjust position to match filesystem
    episodes_dir = project_path / "episodes"
    checkpoint_dir = ProjectPaths.from_root(project_path).checkpoints_dir

    episode_numbers = _existing_episode_numbers(episodes_dir)
    if not episode_numbers:
        return state

    # Numeric max: a lexicographic sort would rank ep_9 above ep_10.
    highest_ep = max(episode_numbers)
    highest_batch = ((highest_ep - 1) // batch_size) + 1

    # Verify with checkpoints — only trust contiguous checkpointed batches,
    # stopping at the first batch without a checkpoint.
    checkpointed = _checkpointed_batches(checkpoint_dir)
    verified_batch = 0
    while verified_batch < highest_batch and (verified_batch + 1) in checkpointed:
        verified_batch += 1

    if verified_batch > 0:
        state["position"]["last_completed_batch"] = verified_batch
        state["position"]["last_completed_episode"] = verified_batch * batch_size
        state["position"]["next_batch"] = verified_batch + 1
        print(f"DETECTED: {len(episode_numbers)} episodes exist, {verified_batch} batches checkpointed")
        print(f"SET: next_batch = {verified_batch + 1}")

    return state


def main():
    """
    CLI entry point: build and write <project>/state/orchestrator_state.json.

    Reads the project path from sys.argv[1]. Exits with status 1 when the
    argument is missing, the path does not exist, or the path is not a
    directory. On success, prints a summary of the generated state.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 init_orchestrator_state.py <project_path>")
        print("Example: python3 init_orchestrator_state.py ./leviathan")
        sys.exit(1)

    project_path = Path(sys.argv[1]).resolve()

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}")
        sys.exit(1)
    # A regular file would pass the check above and then fail later with a
    # confusing error when the state directory is created beneath it.
    if not project_path.is_dir():
        print(f"Error: Project path is not a directory: {project_path}")
        sys.exit(1)

    # Ensure state directory exists
    state_dir = ProjectPaths.from_root(project_path).state_dir
    state_dir.mkdir(parents=True, exist_ok=True)

    # Create orchestrator state
    state = create_orchestrator_state(project_path)

    # Write to file (explicit encoding so the output is platform-independent)
    output_path = state_dir / "orchestrator_state.json"
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(state, f, indent=2)

    fmt = state.get("format", "kill_box")
    fmt_state = state.get("format_state", {})
    beat_count = len(fmt_state.get("emotional_beat_map", {}))

    print(f"\n{'='*60}")
    print("ORCHESTRATOR STATE INITIALIZED")
    print(f"{'='*60}")
    print(f"\nProject: {project_path.name}")
    print(f"Format:  {fmt}")
    print(f"Output:  {output_path}")
    print(f"\nThreads loaded: {len(state['thread_tracker'])}")
    for thread in state['thread_tracker'].values():
        # "target_plant" is always present but may be None; show '?' for that
        # case instead of the literal "None".
        plant = thread.get('target_plant')
        plant_label = '?' if plant is None else plant
        print(f"  - {thread['name']}: plant Ep {plant_label}, payoff Ep {thread['target_payoff']}")

    print(f"\nEmotional beats: {beat_count}")
    print(f"Format state sections: {', '.join(fmt_state.keys())}")
    next_batch = state['position']['next_batch']
    if next_batch > 1:
        print(f"Position: Resuming at Batch {next_batch} ({state['position']['last_completed_episode']} episodes checkpointed)")
    else:
        print("Position: Ready for Batch 1")
    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    main()
