#!/usr/bin/python3
"""
Verify Thread Continuity

Checks narrative thread health across the series:
1. All planted threads eventually advance
2. Threads don't go stale (no mention for 15+ episodes)
3. Payoffs happen within target range
4. No orphan threads at series end
5. Enough threads stay active mid-series (minimum active-thread count)

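Usage: python3 verify_thread_continuity.py <project_path> [--format FORMAT]
Example: python3 verify_thread_continuity.py ./leviathan
         python3 verify_thread_continuity.py ./leviathan --format puzzle_box

Returns:
- Exit code 0: All threads healthy, or only warning/info-level issues found
- Exit code 1: Error-level issues found (e.g. orphan threads at series end),
  or bad arguments / missing project / missing or unusable state file
  (details printed)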
"""

import json
import sys
from pathlib import Path

# Add engine tools to path
_SCRIPT_DIR = Path(__file__).parent.resolve()
sys.path.insert(0, str(_SCRIPT_DIR))
sys.path.insert(0, str(_SCRIPT_DIR.parent.parent))  # CLAUDE_PROJECTS, for recoil.*
from recoil.core.paths import ProjectPaths

# Import from single source of truth
from engine_constants import STALE_THRESHOLD, TOTAL_EPISODES, load_format_constants


def load_orchestrator_state(project_path: Path) -> dict | None:
    """Load the project's orchestrator state, if present and usable.

    Args:
        project_path: Project root. The state file is resolved as
            ``ProjectPaths.from_root(project_path).state_dir / "orchestrator_state.json"``.

    Returns:
        The parsed JSON object, or ``None`` when the file does not exist,
        cannot be read or decoded, is not valid JSON, or does not contain a
        JSON object at the top level. Every case except "does not exist"
        prints a diagnostic to stderr.
    """
    state_file = ProjectPaths.from_root(project_path).state_dir / "orchestrator_state.json"
    if not state_file.exists():
        return None

    try:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(state_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: orchestrator_state.json is malformed: {e}", file=sys.stderr)
        print(f"  File: {state_file}", file=sys.stderr)
        print("  Fix the JSON syntax or delete the file to reset state.", file=sys.stderr)
        return None
    except (OSError, UnicodeDecodeError) as e:
        print(f"ERROR: orchestrator_state.json could not be read: {e}", file=sys.stderr)
        print(f"  File: {state_file}", file=sys.stderr)
        return None

    # Callers treat the state as a mapping (state.get(...)); reject e.g. a bare list.
    if not isinstance(data, dict):
        print(
            f"ERROR: orchestrator_state.json must contain a JSON object, got {type(data).__name__}.",
            file=sys.stderr,
        )
        print(f"  File: {state_file}", file=sys.stderr)
        return None

    return data


def verify_threads(
    state: dict,
    format_name: str | None = None,
    *,
    default_target_payoff: int = 60,
    default_min_thread_count: int = 6,
    static_after: int = 10,
) -> list[dict]:
    """
    Verify thread continuity.
    Returns list of issues found.

    Args:
        state: Orchestrator state dict. Reads
            ``state["position"]["last_completed_episode"]`` and
            ``state["thread_tracker"]`` (thread id -> dict with ``status``,
            ``target_payoff``, ``planted_episode``, ``advanced_episodes``).
        format_name: Optional format name to load format-specific thresholds
            (``STALE_THRESHOLD``, ``MIN_THREAD_COUNT``, ``TOTAL_EPISODES``).
            Falls back to engine-level defaults if not provided.
        default_target_payoff: Payoff episode assumed for threads that have
            no ``target_payoff``.
        default_min_thread_count: Minimum active-thread count used when the
            format does not define ``MIN_THREAD_COUNT``.
        static_after: Number of episodes after planting before a thread still
            in ``planted`` status is reported as ``static_thread``.

    Returns:
        A list of issue dicts. Each has ``type``, ``severity`` ("error",
        "warning" or "info") and ``recommendation``. Per-thread issues also
        carry ``thread``; the series-wide ``low_thread_count`` issue does not.
    """
    # Load format-specific thresholds or use engine defaults.
    fmt_constants = (load_format_constants(format_name) or {}) if format_name else {}
    stale_threshold = fmt_constants.get('STALE_THRESHOLD', STALE_THRESHOLD)
    min_thread_count = fmt_constants.get('MIN_THREAD_COUNT', default_min_thread_count)
    total_episodes = fmt_constants.get('TOTAL_EPISODES', TOTAL_EPISODES)

    active_statuses = ("planted", "advancing")
    issues = []
    tracker = state.get("thread_tracker", {})
    position = state.get("position", {})
    current_episode = position.get("last_completed_episode", 0)
    series_complete = current_episode >= total_episodes

    # Check: Minimum thread count. This runs only mid-series. Once the series is
    # complete, every thread must be paid off (orphan check below), so a low
    # active count is the desired end state, not something to warn about.
    active_threads = sum(1 for t in tracker.values() if t.get("status") in active_statuses)
    if current_episode > 0 and not series_complete and active_threads < min_thread_count:
        issues.append({
            "type": "low_thread_count",
            "count": active_threads,
            "minimum": min_thread_count,
            "severity": "warning",
            "recommendation": f"Only {active_threads} active threads (minimum {min_thread_count}). Consider planting more narrative threads."
        })

    for thread_id, thread in tracker.items():
        status = thread.get("status", "pending")
        target_payoff = thread.get("target_payoff", default_target_payoff)
        planted_ep = thread.get("planted_episode")
        # Tolerate an explicit null in the JSON as well as a missing key.
        advanced_eps = thread.get("advanced_episodes") or []
        is_active = status in active_statuses

        # Check 1: Overdue payoff
        if is_active and current_episode > target_payoff:
            issues.append({
                "type": "overdue_payoff",
                "thread": thread_id,
                "target": target_payoff,
                "current": current_episode,
                "severity": "warning",
                "recommendation": f"Thread '{thread_id}' is overdue for payoff. Target was Ep {target_payoff}, now at Ep {current_episode}."
            })

        # Check 2: Stale thread (no plant/advance mention for threshold+ episodes)
        if is_active and planted_ep:
            last_mention = max([planted_ep, *advanced_eps])
            episodes_since = current_episode - last_mention

            if episodes_since >= stale_threshold:
                issues.append({
                    "type": "stale_thread",
                    "thread": thread_id,
                    "last_mention": last_mention,
                    "episodes_since": episodes_since,
                    "severity": "warning",
                    "recommendation": f"Thread '{thread_id}' is stale. Last mention Ep {last_mention}, {episodes_since} episodes ago. Consider advancing or paying off."
                })

        # Check 3: Orphan thread at series end (still active = never paid off)
        if series_complete and is_active:
            issues.append({
                "type": "orphan_thread",
                "thread": thread_id,
                "status": status,
                "severity": "error",
                "recommendation": f"Thread '{thread_id}' was never paid off. Status: {status}. This is an unresolved plot point."
            })

        # Check 4: Thread planted but never advanced
        if status == "planted" and current_episode > (planted_ep or 0) + static_after:
            issues.append({
                "type": "static_thread",
                "thread": thread_id,
                "planted_at": planted_ep,
                "severity": "info",
                "recommendation": f"Thread '{thread_id}' planted at Ep {planted_ep} but never advanced. Consider developing or removing."
            })

    return issues


_STATUS_ICONS = {
    "pending": "○",
    "planted": "◐",
    "advancing": "◑",
    "paid_off": "●",
}

_SEVERITY_ICONS = {
    "error": "✗",
    "warning": "⚠",
    "info": "ℹ",
}


def _print_usage() -> None:
    """Print command-line usage to stdout."""
    print("Usage: python3 verify_thread_continuity.py <project_path> [--format FORMAT]")
    print("Example: python3 verify_thread_continuity.py ./leviathan")
    print("         python3 verify_thread_continuity.py ./leviathan --format puzzle_box")


def _parse_args(argv: list[str]) -> tuple[Path, str | None]:
    """Return (resolved project path, optional format name) parsed from ``argv``.

    Exits with status 1 when ``--format`` has no value or no project path is given.
    """
    args = list(argv[1:])
    format_name = None

    if '--format' in args:
        fmt_idx = args.index('--format')
        if fmt_idx + 1 >= len(args):
            print("Error: --format requires a value.", file=sys.stderr)
            _print_usage()
            sys.exit(1)
        format_name = args[fmt_idx + 1]
        del args[fmt_idx:fmt_idx + 2]  # drop "--format" and its value

    # Also covers `script --format X` with no project path.
    if not args:
        _print_usage()
        sys.exit(1)

    return Path(args[0]).resolve(), format_name


def _print_thread_status(tracker: dict) -> None:
    """Print a one-line summary per tracked thread."""
    print(f"\n{'─'*60}")
    print("THREAD STATUS:")
    print(f"{'─'*60}")

    for thread_id, thread in tracker.items():
        status = thread.get("status", "pending")
        planted = thread.get("planted_episode", "—")
        payoff = thread.get("payoff_episode", "—")
        target = thread.get("target_payoff", 60)
        status_icon = _STATUS_ICONS.get(status, "?")

        if status == "paid_off":
            print(f"  {status_icon} {thread_id}: planted Ep {planted} → paid off Ep {payoff}")
        elif status in ("planted", "advancing"):
            print(f"  {status_icon} {thread_id}: planted Ep {planted}, target Ep {target} ({status})")
        else:
            # Show the real status so unknown values aren't mislabelled "pending".
            print(f"  {status_icon} {thread_id}: {status} (target Ep {target})")


def _print_issues(issues: list) -> None:
    """Print every issue with its severity icon and recommendation."""
    print(f"\n{'─'*60}")
    print(f"ISSUES FOUND: {len(issues)}")
    print(f"{'─'*60}")

    for issue in issues:
        severity_icon = _SEVERITY_ICONS.get(issue.get("severity"), "?")
        # Series-wide issues (e.g. low_thread_count) carry no "thread" key.
        header = f"[{issue['type'].upper()}]"
        if issue.get("thread") is not None:
            header += f" {issue['thread']}"
        print(f"\n  {severity_icon} {header}")
        print(f"    {issue['recommendation']}")

    print(f"\n{'='*60}")


def main():
    """Command-line entry point; always terminates via ``sys.exit``.

    Exit status is 1 for usage errors, a missing project path, a missing or
    unusable state file, or any error-severity issue. Otherwise it is 0:
    warning- and info-level issues are reported but do not fail the run.
    """
    project_path, format_name = _parse_args(sys.argv)

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}", file=sys.stderr)
        sys.exit(1)

    state = load_orchestrator_state(project_path)
    if not state:
        print(
            "Error: orchestrator_state.json is missing or unusable. "
            "Run init_orchestrator_state.py first.",
            file=sys.stderr,
        )
        sys.exit(1)

    position = state.get("position", {})
    current_ep = position.get("last_completed_episode", 0)
    tracker = state.get("thread_tracker", {})

    print(f"\n{'='*60}")
    print("THREAD CONTINUITY CHECK")
    print(f"{'='*60}")
    print(f"\nProject: {project_path.name}")
    if format_name:
        print(f"Format: {format_name}")
    print(f"Current position: Episode {current_ep}")
    print(f"Threads tracked: {len(tracker)}")

    issues = verify_threads(state, format_name=format_name)

    _print_thread_status(tracker)

    if issues:
        _print_issues(issues)
        # Return non-zero only for errors (warnings/info are advisory).
        has_errors = any(i.get("severity") == "error" for i in issues)
        sys.exit(1 if has_errors else 0)

    print(f"\n{'─'*60}")
    print("✓ All threads healthy")
    print(f"{'='*60}\n")
    sys.exit(0)


if __name__ == "__main__":
    main()
