#!/usr/bin/python3
"""
Update Orchestrator State

Updates orchestrator_state.json after each batch completes by:
1. Merging batch summary data
2. Updating thread tracker status
3. Updating emotional beat status
4. Updating pattern state
5. Running cross-batch verification
6. Flagging any issues

Usage: python3 update_orchestrator_state.py <project_path> <batch_number>
Example: python3 update_orchestrator_state.py ./leviathan 3

Reads: ./[project]/state/batch_NN_summary.json (batch number zero-padded to two digits, e.g. batch_03_summary.json)
Updates: ./[project]/state/orchestrator_state.json
"""

import json
import sys
from datetime import datetime
from pathlib import Path

# Add engine tools to path for imports
_SCRIPT_DIR = Path(__file__).parent.resolve()
sys.path.insert(0, str(_SCRIPT_DIR))
sys.path.insert(0, str(_SCRIPT_DIR.parent.parent))  # CLAUDE_PROJECTS, for recoil.*
from recoil.core.paths import ProjectPaths

# Goal-backward targets from CONSTANTS.md
# Keyed by checkpoint batch number; run_goal_backward_check() only produces a
# checkpoint for batches listed here. Field usage:
#   "threads_resolved" - range string such as "2-3" or "6+"; its lower bound is
#                        the minimum number of paid-off threads expected.
#   "emotional_beats"  - "hit/total"; the numerator is the minimum beats expected.
#   "arc_progress"     - stored/printed next to the actual value, not compared.
#   "episode"          - informational; the check recomputes it as batch * 5.
GOAL_BACKWARD_TARGETS = {
    3: {"episode": 15, "threads_resolved": "0-1", "arc_progress": "25%", "emotional_beats": "2/11"},
    6: {"episode": 30, "threads_resolved": "2-3", "arc_progress": "50%", "emotional_beats": "5/11"},
    9: {"episode": 45, "threads_resolved": "4-5", "arc_progress": "75%", "emotional_beats": "8/11"},
    12: {"episode": 60, "threads_resolved": "6+", "arc_progress": "100%", "emotional_beats": "11/11"}
}

# Longest allowed run of identical hook / cliffhanger types; a pattern
# violation is flagged once a run exceeds this value. Prefer the shared engine
# constant and fall back to 3 when engine_constants is not importable.
try:
    from engine_constants import MAX_CONSECUTIVE_SAME_TYPE
except ImportError:
    MAX_CONSECUTIVE_SAME_TYPE = 3


def load_orchestrator_state(project_path: Path) -> dict:
    """Load the existing orchestrator state for a project.

    Args:
        project_path: Project root. The state file is
            ``orchestrator_state.json`` inside
            ``ProjectPaths.from_root(project_path).state_dir``.

    Returns:
        The parsed JSON document.

    Exits the process with status 1 (message on stderr) when the file is
    missing, is not valid JSON, or is not UTF-8 encoded.
    """
    state_file = ProjectPaths.from_root(project_path).state_dir / "orchestrator_state.json"
    if not state_file.exists():
        print(
            f"Error: Orchestrator state not found: {state_file}. "
            "Run init_orchestrator_state.py first.",
            file=sys.stderr,
        )
        sys.exit(1)

    try:
        with open(state_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # Both mean the file exists but cannot be trusted as state.
        print(f"Error: Corrupted orchestrator state JSON: {e}", file=sys.stderr)
        sys.exit(1)


def load_batch_summary(project_path: Path, batch_num: int) -> dict:
    """Load the summary JSON written for a completed batch.

    Args:
        project_path: Project root; the summary is read from
            ``ProjectPaths.from_root(project_path).state_dir``.
        batch_num: Batch number; the file name is zero-padded to two digits
            (``batch_03_summary.json`` for batch 3).

    Returns:
        The parsed JSON document.

    Exits the process with status 1 (message on stderr) when the file is
    missing, is not valid JSON, or is not UTF-8 encoded.
    """
    summary_file = ProjectPaths.from_root(project_path).state_dir / f"batch_{batch_num:02d}_summary.json"
    if not summary_file.exists():
        print(f"Error: Batch summary not found: {summary_file}", file=sys.stderr)
        sys.exit(1)

    try:
        with open(summary_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"Error: Corrupted batch summary JSON: {e}", file=sys.stderr)
        sys.exit(1)


def update_thread_tracker(state: dict, batch_summary: dict) -> list:
    """
    Update thread tracker based on batch summary.

    Episodes in ``batch_summary["episodes"]`` are processed in order. For each
    one, thread IDs listed under ``threads.planted``, ``threads.advanced`` and
    ``threads.paid_off`` update the matching ``state["thread_tracker"]`` entry
    (IDs not in the tracker are ignored):

    * planted  -> status "planted", ``planted_episode`` = episode number
    * advanced -> status "advancing"; a still-"pending" thread with no
      ``planted_episode`` is treated as implicitly planted in this episode;
      the episode is appended once to ``advanced_episodes``
    * paid_off -> status "paid_off", ``payoff_episode`` = episode number

    "paid_off" is terminal: a later planted/advanced mention never moves a
    resolved thread back to an open status (it would otherwise stop counting
    as resolved and could be reported as overdue).

    After processing, every thread still "pending", "planted" or "advancing"
    whose ``target_payoff`` (60 when missing or null) is below the batch's
    last episode (``batch * 5``) is reported as overdue.

    Returns:
        List of warning dicts with keys ``thread_id``, ``expected_by`` and
        ``current_episode``.
    """
    warnings = []
    tracker = state["thread_tracker"]
    current_episode = batch_summary["batch"] * 5

    for ep_data in batch_summary["episodes"]:
        ep_num = ep_data["number"]
        threads = ep_data.get("threads") or {}

        for thread_id in threads.get("planted") or []:
            thread = tracker.get(thread_id)
            if thread is None or thread.get("status") == "paid_off":
                continue
            thread["status"] = "planted"
            thread["planted_episode"] = ep_num

        for thread_id in threads.get("advanced") or []:
            thread = tracker.get(thread_id)
            if thread is None:
                continue
            if thread.get("status") != "paid_off":
                if thread.get("status") == "pending" and not thread.get("planted_episode"):
                    # Advanced without an explicit plant: it was implicitly planted here.
                    thread["planted_episode"] = ep_num
                thread["status"] = "advancing"
            advanced_episodes = thread.setdefault("advanced_episodes", [])
            if ep_num not in advanced_episodes:
                advanced_episodes.append(ep_num)

        for thread_id in threads.get("paid_off") or []:
            thread = tracker.get(thread_id)
            if thread is None:
                continue
            thread["status"] = "paid_off"
            thread["payoff_episode"] = ep_num

    # Check for overdue threads
    for thread_id, thread in tracker.items():
        if thread.get("status") not in ("pending", "planted", "advancing"):
            continue
        target_payoff = thread.get("target_payoff")
        if target_payoff is None:
            # A null target in the JSON must not crash the comparison below.
            target_payoff = 60
        if current_episode > target_payoff:
            warnings.append({
                "thread_id": thread_id,
                "expected_by": target_payoff,
                "current_episode": current_episode
            })

    return warnings


def update_emotional_beats(state: dict, batch_summary: dict) -> None:
    """Mark emotional beats as hit for the episodes in a batch summary.

    An episode's ``emotional_beat`` is applied only when it is truthy and is a
    key of ``state["emotional_beat_map"]``. The matching entry gets
    ``status = "hit"`` and ``actual_episode`` set to that episode's number;
    when a beat appears in several episodes the last one wins.
    """
    beats = state["emotional_beat_map"]
    for episode in batch_summary["episodes"]:
        number = episode["number"]
        beat_name = episode.get("emotional_beat")
        if not beat_name or beat_name not in beats:
            continue
        entry = beats[beat_name]
        entry["status"] = "hit"
        entry["actual_episode"] = number


def update_pattern_state(state: dict, batch_summary: dict) -> list:
    """
    Advance hook/cliffhanger counters for each episode and collect violations.

    Hooks: ``hook_type == "silent"`` (the default when absent) counts as
    silent, any other value as dialogue. Cliffhangers:
    ``cliffhanger_type == "mid-action"`` (the default) counts as mid-action,
    anything else as aftermath. For the matching bucket the running total and
    the consecutive-run counter are incremented, the opposite run counter is
    reset to 0, and the raw type string is appended to ``history``.

    After each episode's update, any run counter above
    MAX_CONSECUTIVE_SAME_TYPE produces a violation entry. Hook runs are checked
    for every episode; cliffhanger runs only for episode numbers above 10
    (pattern variety is not enforced in the pilot). A long run therefore
    yields one entry per episode past the limit.

    Returns:
        List of violation dicts with keys ``batch``, ``type``,
        ``consecutive_count`` and ``pattern``.
    """
    found = []
    pattern = state["pattern_state"]
    batch_id = batch_summary["batch"]

    def flag(kind, run_length, label):
        found.append({
            "batch": batch_id,
            "type": kind,
            "consecutive_count": run_length,
            "pattern": label
        })

    for episode in batch_summary["episodes"]:
        hook = episode.get("hook_type", "silent")
        cliff = episode.get("cliffhanger_type", "mid-action")

        hooks = pattern["hooks"]
        if hook != "silent":
            hooks["dialogue_count"] += 1
            hooks["consecutive_dialogue"] += 1
            hooks["consecutive_silent"] = 0
        else:
            hooks["silent_count"] += 1
            hooks["consecutive_silent"] += 1
            hooks["consecutive_dialogue"] = 0
        hooks["history"].append(hook)

        for label, run_key in (("silent", "consecutive_silent"), ("dialogue", "consecutive_dialogue")):
            if hooks[run_key] > MAX_CONSECUTIVE_SAME_TYPE:
                flag("hook", hooks[run_key], label)

        cliffs = pattern["cliffhangers"]
        if cliff != "mid-action":
            cliffs["aftermath_count"] += 1
            cliffs["consecutive_aftermath"] += 1
            cliffs["consecutive_mid_action"] = 0
        else:
            cliffs["mid_action_count"] += 1
            cliffs["consecutive_mid_action"] += 1
            cliffs["consecutive_aftermath"] = 0
        cliffs["history"].append(cliff)

        # Cliffhanger variety only applies once we are past the pilot episodes.
        if episode["number"] > 10:
            for label, run_key in (("mid-action", "consecutive_mid_action"), ("aftermath", "consecutive_aftermath")):
                if cliffs[run_key] > MAX_CONSECUTIVE_SAME_TYPE:
                    flag("cliffhanger", cliffs[run_key], label)

    return found


def run_goal_backward_check(state: dict, batch_num: int) -> dict | None:
    """
    Run goal-backward verification at checkpoint batches (3, 6, 9, 12).
    Returns checkpoint data or None if not a checkpoint batch.
    """
    if batch_num not in GOAL_BACKWARD_TARGETS:
        return None

    targets = GOAL_BACKWARD_TARGETS[batch_num]

    # Count threads resolved
    threads_resolved = sum(
        1 for t in state["thread_tracker"].values()
        if t["status"] == "paid_off"
    )

    # Count emotional beats hit
    beats_hit = sum(
        1 for b in state["emotional_beat_map"].values()
        if b["status"] == "hit"
    )

    # Calculate arc progress
    current_episode = batch_num * 5
    arc_progress = f"{round(current_episode / 60 * 100)}%"

    # Determine status
    expected_beats = int(targets["emotional_beats"].split("/")[0])
    status = "on_track"
    course_corrections = []

    # Parse expected threads range
    threads_range = targets["threads_resolved"]
    if "+" in threads_range:
        min_threads = int(threads_range.replace("+", ""))
        if threads_resolved < min_threads:
            status = "behind"
            course_corrections.append(f"Accelerate thread payoffs - {min_threads - threads_resolved} threads behind target")
    elif "-" in threads_range:
        parts = threads_range.split("-")
        min_threads, max_threads = int(parts[0]), int(parts[1])
        if threads_resolved < min_threads:
            status = "behind"
            course_corrections.append(f"Accelerate thread payoffs in next batch")
    else:
        min_threads = int(threads_range)
        if threads_resolved < min_threads:
            status = "behind"

    # Check beats
    if beats_hit < expected_beats:
        if status == "on_track":
            status = "behind"
        course_corrections.append(f"Insert recovery beat - {expected_beats - beats_hit} beats behind schedule")

    checkpoint = {
        "episode": current_episode,
        "threads_resolved_expected": threads_range,
        "threads_resolved_actual": threads_resolved,
        "arc_progress_expected": targets["arc_progress"],
        "arc_progress_actual": arc_progress,
        "emotional_beats_expected": targets["emotional_beats"],
        "emotional_beats_actual": beats_hit,
        "status": status,
        "course_corrections": course_corrections
    }

    return checkpoint


def update_position(state: dict, batch_num: int) -> None:
    """Record ``batch_num`` as the most recently completed batch.

    Writes ``last_completed_batch``, ``last_completed_episode`` (five episodes
    per batch) and ``next_batch`` into ``state["position"]``.
    """
    position = state["position"]
    position.update(
        last_completed_batch=batch_num,
        last_completed_episode=batch_num * 5,
        next_batch=batch_num + 1,
    )


def save_orchestrator_state(state: dict, project_path: Path) -> None:
    """Persist the orchestrator state to ``orchestrator_state.json``.

    Side effect: sets ``state["meta"]["last_updated"]`` to the current local
    time (ISO-8601) before writing.

    The JSON is written to a sibling ``.tmp`` file first and then moved over
    the real file with ``Path.replace``. An interruption mid-write therefore
    leaves the previous state intact, instead of a truncated file that
    load_orchestrator_state would reject as corrupted.
    """
    state["meta"]["last_updated"] = datetime.now().isoformat()

    output_path = ProjectPaths.from_root(project_path).state_dir / "orchestrator_state.json"
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(state, f, indent=2)
        tmp_path.replace(output_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the real file is untouched.
        tmp_path.unlink(missing_ok=True)
        raise


def _record_batch_summary(state: dict, batch_num: int, batch_summary: dict) -> None:
    """Store a compact reference to this batch under ``state["batch_summaries"]``."""
    summary = batch_summary["batch_summary"]
    distribution = summary["pattern_distribution"]
    state["batch_summaries"][f"batch_{batch_num}"] = {
        "batch": batch_num,
        "episodes": f"{(batch_num - 1) * 5 + 1}-{batch_num * 5}",
        "threads_planted": summary.get("threads_planted_this_batch", []),
        "threads_advanced": summary.get("threads_advanced_this_batch", []),
        "threads_paid_off": summary.get("threads_paid_off_this_batch", []),
        "emotional_beats_hit": summary.get("emotional_beats_hit_this_batch", []),
        "hook_distribution": distribution["hooks"],
        "cliffhanger_distribution": distribution["cliffhangers"]
    }


def _report_cross_batch_flags(state: dict, overdue_threads: list, pattern_violations: list) -> None:
    """Append new warnings to ``state["cross_batch_flags"]`` and print them."""
    if overdue_threads:
        state["cross_batch_flags"]["overdue_threads"].extend(overdue_threads)
        print(f"\nWARNING: {len(overdue_threads)} overdue thread(s) detected:")
        for t in overdue_threads:
            print(f"  - {t['thread_id']}: expected by Ep {t['expected_by']}, now at Ep {t['current_episode']}")

    if pattern_violations:
        state["cross_batch_flags"]["pattern_violations"].extend(pattern_violations)
        print(f"\nWARNING: {len(pattern_violations)} pattern violation(s) detected:")
        for v in pattern_violations:
            print(f"  - {v['type']}: {v['consecutive_count']} consecutive {v['pattern']}")


def _print_checkpoint(batch_num: int, checkpoint: dict) -> None:
    """Print a goal-backward checkpoint report."""
    print(f"\nGOAL-BACKWARD CHECK (Batch {batch_num}):")
    print(f"  Status: {checkpoint['status'].upper()}")
    print(f"  Threads: {checkpoint['threads_resolved_actual']} / {checkpoint['threads_resolved_expected']}")
    print(f"  Beats: {checkpoint['emotional_beats_actual']} / {checkpoint['emotional_beats_expected']}")
    print(f"  Arc: {checkpoint['arc_progress_actual']} / {checkpoint['arc_progress_expected']}")
    if checkpoint["course_corrections"]:
        print("  Course corrections needed:")
        for cc in checkpoint["course_corrections"]:
            print(f"    - {cc}")


def _print_final_summary(state: dict, batch_num: int) -> None:
    """Print position, cumulative pattern distribution and arc progress."""
    rule = "=" * 60
    total_episodes = batch_num * 5
    hooks = state["pattern_state"]["hooks"]
    cliffs = state["pattern_state"]["cliffhangers"]

    print(f"\n{rule}")
    print("STATE UPDATED")
    print(rule)
    print(f"Position: Batch {batch_num} complete ({total_episodes}/60 episodes)")
    print(f"Next batch: {batch_num + 1}")
    print("\nPattern Distribution:")
    print(f"  Hooks: {hooks['silent_count']} silent / {hooks['dialogue_count']} dialogue")
    if total_episodes > 0:
        print(f"         ({round(hooks['silent_count']/total_episodes*100)}% silent)")
    print(f"  Cliffhangers: {cliffs['mid_action_count']} mid-action / {cliffs['aftermath_count']} aftermath")
    if total_episodes > 0:
        print(f"               ({round(cliffs['mid_action_count']/total_episodes*100)}% mid-action)")

    threads = state["thread_tracker"].values()
    threads_paid = sum(1 for t in threads if t["status"] == "paid_off")
    threads_active = sum(1 for t in threads if t["status"] in ["planted", "advancing"])
    beats_hit = sum(1 for b in state["emotional_beat_map"].values() if b["status"] == "hit")

    print("\nArc Progress:")
    print(f"  Threads: {threads_paid} paid off, {threads_active} active")
    print(f"  Emotional beats: {beats_hit}/11 hit")
    print(f"\n{rule}\n")


def main():
    """Command-line entry point: apply one completed batch to the orchestrator state.

    Usage: ``update_orchestrator_state.py <project_path> <batch_number>``.
    Exits with status 1 (messages on stderr) on missing or invalid arguments
    and on missing or corrupt input files.
    """
    if len(sys.argv) < 3:
        print("Usage: python3 update_orchestrator_state.py <project_path> <batch_number>", file=sys.stderr)
        print("Example: python3 update_orchestrator_state.py ./leviathan 3", file=sys.stderr)
        sys.exit(1)

    project_path = Path(sys.argv[1]).resolve()
    try:
        batch_num = int(sys.argv[2])
    except ValueError:
        print(f"Error: Batch number must be an integer, got: {sys.argv[2]!r}", file=sys.stderr)
        sys.exit(1)
    if batch_num < 1:
        # Episode ranges are computed as (batch-1)*5+1 .. batch*5, so batches start at 1.
        print(f"Error: Batch number must be 1 or greater, got: {batch_num}", file=sys.stderr)
        sys.exit(1)

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}", file=sys.stderr)
        sys.exit(1)

    state = load_orchestrator_state(project_path)
    batch_summary = load_batch_summary(project_path, batch_num)

    # update_thread_tracker/update_pattern_state use the summary's own "batch",
    # while everything below uses the CLI argument; surface any disagreement.
    if batch_summary.get("batch") != batch_num:
        print(
            f"WARNING: batch summary reports batch {batch_summary.get('batch')!r}, expected {batch_num}",
            file=sys.stderr,
        )

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"UPDATING ORCHESTRATOR STATE: Batch {batch_num}")
    print(rule)

    overdue_threads = update_thread_tracker(state, batch_summary)
    update_emotional_beats(state, batch_summary)
    pattern_violations = update_pattern_state(state, batch_summary)

    _record_batch_summary(state, batch_num, batch_summary)
    _report_cross_batch_flags(state, overdue_threads, pattern_violations)

    checkpoint = run_goal_backward_check(state, batch_num)
    if checkpoint:
        state["goal_backward_checkpoints"][f"batch_{batch_num}"] = checkpoint
        _print_checkpoint(batch_num, checkpoint)

    update_position(state, batch_num)
    save_orchestrator_state(state, project_path)
    _print_final_summary(state, batch_num)


if __name__ == "__main__":
    main()
