#!/usr/bin/env python3
"""
Thread Tracking Tool
Tracks plant/payoff thread status during episode generation.

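Usage:
    python track_threads.py <project_path> status [--format FORMAT]
    python track_threads.py <project_path> plant <thread_name> <episode>
    python track_threads.py <project_path> payoff <thread_name> <episode>
    python track_threads.py <project_path> check <episode>

Options:
    --format FORMAT   Use format-specific constants (e.g., kill_box, puzzle_box);
                      currently only the status command reads them.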

Examples:
    python track_threads.py ./leviathan status
    python track_threads.py ./leviathan plant "EXPENDABLE" 1
    python track_threads.py ./leviathan payoff "EXPENDABLE" 56
    python track_threads.py ./leviathan check 30
"""

import sys
import json
import re
from pathlib import Path
from datetime import datetime

# Add engine tools to path: put this script's own directory first on
# sys.path so modules that presumably live next to it (engine_constants)
# import regardless of the current working directory. This must run before
# the import below.
_SCRIPT_DIR = Path(__file__).parent.resolve()
sys.path.insert(0, str(_SCRIPT_DIR))

# Per-format constants loader; called by _get_format_thresholds.
from engine_constants import load_format_constants

# ANSI colors: SGR escape sequences for terminal output. RESET (SGR 0)
# clears every attribute set by the others.
_CSI = '\033['  # Control Sequence Introducer prefixing each SGR code
GREEN = f'{_CSI}92m'
RED = f'{_CSI}91m'
YELLOW = f'{_CSI}93m'
CYAN = f'{_CSI}96m'
BOLD = f'{_CSI}1m'
RESET = f'{_CSI}0m'


def load_state(project_path):
    """Load ``<project_path>/state/current_state.json``.

    Args:
        project_path: Project root as a ``pathlib.Path``.

    Returns:
        The parsed state dict, or None when the file is missing, cannot be
        decoded as UTF-8 JSON, or does not hold a JSON object at the top
        level (callers in this module read the result with ``dict.get``).
        A decoding/shape problem is reported on stdout before returning None.
    """
    state_path = project_path / 'state' / 'current_state.json'
    if not state_path.exists():
        return None

    try:
        with open(state_path, 'r', encoding='utf-8') as f:
            state = json.load(f)
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError both subclass
        # ValueError; catching only the former let undecodable bytes
        # escape as a traceback instead of this message.
        print(f"{RED}Error: Corrupted state file: {e}{RESET}")
        return None

    if not isinstance(state, dict):
        print(f"{RED}Error: Corrupted state file: expected a JSON object, "
              f"got {type(state).__name__}{RESET}")
        return None
    return state


def save_state(project_path, state):
    """Write *state* to ``<project_path>/state/current_state.json``.

    Stamps ``state['last_updated']`` with the current (naive, local) time
    before writing; this mutates the caller's dict. The JSON is written to a
    sibling ``.tmp`` file first and then moved over the real file. A crash or
    serialisation error mid-write therefore cannot leave a truncated state
    file behind.
    """
    state_path = project_path / 'state' / 'current_state.json'
    state['last_updated'] = datetime.now().isoformat()

    tmp_path = state_path.with_name(state_path.name + '.tmp')
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(state, f, indent=2)
        # Path.replace overwrites an existing destination (unlike
        # Path.rename on Windows) and is atomic on POSIX filesystems.
        tmp_path.replace(state_path)
    except BaseException:
        # The real state file is still untouched here; drop the partial
        # temp file and propagate the original error.
        if tmp_path.exists():
            tmp_path.unlink()
        raise

    print(f"  {GREEN}✓{RESET} State saved")


def load_plant_payoff_plan(project_path):
    """Load and parse the project's plant/payoff plan.

    Looks for ``bible/dev_plant_payoff_plan.md`` first and falls back to
    ``development/plant_payoff_plan.md``. Each thread is a markdown section
    shaped like::

        ## Thread 1: NAME
        **Type:** <word>
        ... **Plant:** Ep <n> ... **Payoff:** Ep <n>

    The Type line must directly follow the heading (only empty lines may
    separate them). Plant and Payoff must appear, in that order, before the
    next level-1/level-2 heading or the next "Thread N" heading. "Ep" may also
    be spelled "Episode", and matching is case-insensitive. Sections missing
    either field are skipped.

    Returns:
        None if no plan file exists. Otherwise a list (possibly empty) of
        dicts with 'name', 'type' (lower-cased), 'plant_episode' and
        'payoff_episode' (ints), plus 'planted'/'paid_off' (always False).
    """
    # Try production path first
    pp_path = project_path / 'bible' / 'dev_plant_payoff_plan.md'
    if not pp_path.exists():
        # Fall back to the development path
        pp_path = project_path / 'development' / 'plant_payoff_plan.md'

    if not pp_path.exists():
        return None

    with open(pp_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Lazy gap that may not run across a section boundary (a "#"/"##"
    # heading or any "Thread N" heading). Without this guard, a thread
    # missing its Payoff line borrowed the next thread's Payoff and
    # swallowed that whole thread.
    section_gap = r'(?:(?!\n(?:#{1,2}(?!#)|#+\s*Thread\s*\d)).)*?'
    thread_pattern = re.compile(
        r'##\s*Thread\s*\d+[:\s]+([^\n]+)\n+\*\*Type:\*\*\s*(\w+)'
        + section_gap
        + r'\*\*Plant:\*\*\s*Ep(?:isode)?\s*(\d+)'
        + section_gap
        + r'\*\*Payoff:\*\*\s*Ep(?:isode)?\s*(\d+)',
        re.IGNORECASE | re.DOTALL,
    )

    return [
        {
            'name': match.group(1).strip().strip('"'),
            'type': match.group(2).strip().lower(),
            'plant_episode': int(match.group(3)),
            'payoff_episode': int(match.group(4)),
            'planted': False,
            'paid_off': False,
        }
        for match in thread_pattern.finditer(content)
    ]


def _get_format_thresholds(format_name=None):
    """Get stale threshold and min thread count from format constants or defaults."""
    if format_name:
        constants = load_format_constants(format_name)
        return {
            'stale_threshold': constants.get('STALE_THRESHOLD', 15),
            'min_thread_count': constants.get('MIN_THREAD_COUNT', 6),
        }
    return {
        'stale_threshold': 15,
        'min_thread_count': 6,
    }


def _entry_thread_name(entry):
    """Return the thread name held by one threads_planted/threads_paid_off entry.

    Entries are either bare strings (the name itself) or dicts carrying the
    name under 'thread' (what record_plant/record_payoff write) or 'name'.
    A dict with neither key yields None.
    """
    if isinstance(entry, dict):
        return entry.get('thread', entry.get('name'))
    return entry


def _print_recorded_entries(title, entries):
    """Print a bold section *title* followed by one bullet per recorded entry."""
    print(f"\n{BOLD}{title}:{RESET}")
    for p in entries:
        if isinstance(p, dict):
            print(f"  - {p.get('thread', p.get('name', 'Unknown'))} (Ep {p.get('episode', '?')})")
        else:
            print(f"  - {p}")


def show_status(project_path, format_name=None):
    """Print a status report of planned versus recorded plant/payoff threads.

    Reads the state file and, if present, the plant/payoff plan. It then
    prints:
      * the planned-thread table;
      * the recorded plants and payoffs;
      * warnings for overdue plants/payoffs and stale threads;
      * anything due within the next five episodes.

    Args:
        project_path: Project root (a ``pathlib.Path``).
        format_name: Optional format whose constants override the default
            stale threshold.

    Returns:
        False if no usable state file was found, True otherwise.
    """
    state = load_state(project_path)
    # `is None` rather than falsiness: an empty but valid state ({}) is
    # still reportable.
    if state is None:
        print(f"{RED}Error: No state file found{RESET}")
        return False

    thresholds = _get_format_thresholds(format_name)

    print(f"\n{'═' * 60}")
    print(f"{BOLD}THREAD STATUS: {project_path.name.upper()}{RESET}")
    print(f"{'═' * 60}\n")

    current_ep = state.get('last_episode') or 0
    print(f"Current Episode: {current_ep}\n")

    planted = state.get('threads_planted') or []
    paid_off = state.get('threads_paid_off') or []
    # Resolve names once. Entries may be strings or dicts, and one list can
    # mix both. Each entry is therefore normalised individually instead of
    # guessing the format from the first element (which crashed or missed
    # matches on mixed lists).
    planted_names = {_entry_thread_name(p) for p in planted}
    paid_off_names = {_entry_thread_name(p) for p in paid_off}

    # Load planned threads
    planned_threads = load_plant_payoff_plan(project_path)

    if planned_threads:
        print(f"{BOLD}PLANNED THREADS:{RESET}\n")

        for t in planned_threads:
            name = t['name'][:30]
            plant_ep = t['plant_episode']
            payoff_ep = t['payoff_episode']
            is_planted = t['name'] in planted_names
            is_paid = t['name'] in paid_off_names

            plant_status = f"{GREEN}✓{RESET}" if is_planted else (f"{YELLOW}DUE{RESET}" if current_ep >= plant_ep else f"{CYAN}→{plant_ep}{RESET}")
            payoff_status = f"{GREEN}✓{RESET}" if is_paid else (f"{YELLOW}DUE{RESET}" if current_ep >= payoff_ep else f"{CYAN}→{payoff_ep}{RESET}")

            # The :15 width includes the 9 invisible ANSI escape characters.
            print(f"  {t['type'].upper()[:3]:3} | {name:30} | Plant: {plant_status:15} | Payoff: {payoff_status}")

    # Show manually tracked
    if planted:
        _print_recorded_entries('PLANTED', planted)
    if paid_off:
        _print_recorded_entries('PAID OFF', paid_off)

    # Warnings
    print(f"\n{'─' * 60}")

    if planned_threads:
        overdue_plants = [t for t in planned_threads
                          if t['plant_episode'] <= current_ep and t['name'] not in planted_names]
        overdue_payoffs = [t for t in planned_threads
                           if t['payoff_episode'] <= current_ep and t['name'] not in paid_off_names]

        if overdue_plants:
            print(f"\n{YELLOW}⚠ OVERDUE PLANTS:{RESET}")
            for t in overdue_plants:
                print(f"  - {t['name']} (was due Ep {t['plant_episode']})")

        if overdue_payoffs:
            print(f"\n{YELLOW}⚠ OVERDUE PAYOFFS:{RESET}")
            for t in overdue_payoffs:
                print(f"  - {t['name']} (was due Ep {t['payoff_episode']})")

        # Stale threads: planted, not yet paid off, and planted at least
        # `threshold` episodes ago. Age is measured from the plant episode;
        # later mentions of a thread are not tracked.
        threshold = thresholds['stale_threshold']
        first_plant_entry = {}
        for p in planted:
            first_plant_entry.setdefault(_entry_thread_name(p), p)

        stale_threads = []
        for t in planned_threads:
            if t['name'] not in planted_names or t['name'] in paid_off_names:
                continue
            entry = first_plant_entry[t['name']]
            # Prefer the recorded plant episode; fall back to the planned one.
            if isinstance(entry, dict):
                plant_ep = entry.get('episode', t['plant_episode'])
            else:
                plant_ep = t['plant_episode']
            # `is not None` so a thread planted in episode 0 is still checked.
            if plant_ep is not None and current_ep - plant_ep >= threshold:
                stale_threads.append((t['name'], plant_ep))

        if stale_threads:
            print(f"\n{YELLOW}⚠ STALE THREADS (no activity for {threshold}+ episodes):{RESET}")
            for name, ep in stale_threads:
                print(f"  - {name} (planted Ep {ep}, {current_ep - ep} eps ago)")

        # Upcoming
        upcoming_plants = [t for t in planned_threads if current_ep < t['plant_episode'] <= current_ep + 5]
        upcoming_payoffs = [t for t in planned_threads if current_ep < t['payoff_episode'] <= current_ep + 5]

        if upcoming_plants:
            print(f"\n{CYAN}UPCOMING PLANTS (next 5 eps):{RESET}")
            for t in upcoming_plants:
                print(f"  - {t['name']} → Ep {t['plant_episode']}")

        if upcoming_payoffs:
            print(f"\n{CYAN}UPCOMING PAYOFFS (next 5 eps):{RESET}")
            for t in upcoming_payoffs:
                print(f"  - {t['name']} → Ep {t['payoff_episode']}")

    print(f"\n{'═' * 60}\n")
    return True


def record_plant(project_path, thread_name, episode):
    """Record that *thread_name* was planted in *episode*.

    Appends ``{'thread', 'episode', 'timestamp'}`` to
    ``state['threads_planted']`` and saves the state file.

    Returns:
        True if the plant was recorded. False if there is no usable state file
        or the thread is already recorded as planted; nothing is saved then.
    """
    state = load_state(project_path)
    # `is None`: an empty but valid state ({}) can still be recorded into.
    if state is None:
        print(f"{RED}Error: No state file found{RESET}")
        return False

    # A missing or null list starts fresh.
    if state.get('threads_planted') is None:
        state['threads_planted'] = []
    planted = state['threads_planted']

    def entry_name(entry):
        # Entries are bare names or dicts naming the thread under 'thread'
        # (or 'name'); accept both so this agrees with show_status.
        if isinstance(entry, dict):
            return entry.get('thread', entry.get('name'))
        return entry

    # Check if already planted
    if any(entry_name(p) == thread_name for p in planted):
        print(f"{YELLOW}Warning: Thread '{thread_name}' already planted{RESET}")
        return False

    planted.append({
        'thread': thread_name,
        'episode': episode,
        'timestamp': datetime.now().isoformat()
    })

    save_state(project_path, state)
    print(f"{GREEN}✓ Planted '{thread_name}' in Episode {episode}{RESET}")
    return True


def record_payoff(project_path, thread_name, episode):
    """Record that *thread_name* was paid off in *episode*.

    Appends ``{'thread', 'episode', 'timestamp'}`` to
    ``state['threads_paid_off']`` and saves the state file. Paying off a
    thread that was never planted is allowed but prints a warning.

    Returns:
        True if the payoff was recorded. False if there is no usable state
        file or the thread is already recorded as paid off; nothing is saved
        then.
    """
    state = load_state(project_path)
    # `is None`: an empty but valid state ({}) can still be recorded into.
    if state is None:
        print(f"{RED}Error: No state file found{RESET}")
        return False

    # A missing or null list starts fresh.
    if state.get('threads_paid_off') is None:
        state['threads_paid_off'] = []
    paid_off = state['threads_paid_off']

    def entry_name(entry):
        # Entries are bare names or dicts naming the thread under 'thread'
        # (or 'name'); accept both so this agrees with show_status.
        if isinstance(entry, dict):
            return entry.get('thread', entry.get('name'))
        return entry

    # Check if already paid off
    if any(entry_name(p) == thread_name for p in paid_off):
        print(f"{YELLOW}Warning: Thread '{thread_name}' already paid off{RESET}")
        return False

    # Check if planted first
    planted = state.get('threads_planted') or []
    if not any(entry_name(p) == thread_name for p in planted):
        print(f"{YELLOW}Warning: Thread '{thread_name}' being paid off but was never planted!{RESET}")

    paid_off.append({
        'thread': thread_name,
        'episode': episode,
        'timestamp': datetime.now().isoformat()
    })

    save_state(project_path, state)
    print(f"{GREEN}✓ Paid off '{thread_name}' in Episode {episode}{RESET}")
    return True


def check_episode(project_path, episode):
    """Print which planned threads must be planted / paid off in *episode*.

    Also lists the currently active threads (planted but not yet paid off).
    It works without a state file; every scheduled thread then shows as
    PENDING.

    Returns:
        False if no plan (or a plan with no parsable threads) was found,
        True otherwise.
    """
    planned_threads = load_plant_payoff_plan(project_path)
    state = load_state(project_path)

    if not planned_threads:
        print(f"{YELLOW}No plant/payoff plan found{RESET}")
        return False

    print(f"\n{'═' * 60}")
    print(f"{BOLD}EPISODE {episode} THREAD CHECK{RESET}")
    print(f"{'═' * 60}\n")

    def entry_name(entry):
        # Entries are bare names or dicts naming the thread under 'thread'
        # (or 'name'); accept both so this agrees with show_status.
        if isinstance(entry, dict):
            return entry.get('thread', entry.get('name'))
        return entry

    state = state or {}
    planted_names = {entry_name(p) for p in (state.get('threads_planted') or [])}
    paid_off_names = {entry_name(p) for p in (state.get('threads_paid_off') or [])}

    # Plants due
    plants_due = [t for t in planned_threads if t['plant_episode'] == episode]
    if plants_due:
        print(f"{BOLD}MUST PLANT:{RESET}")
        for t in plants_due:
            status = f"{GREEN}✓ DONE{RESET}" if t['name'] in planted_names else f"{RED}PENDING{RESET}"
            print(f"  - {t['name']} ({t['type']}) {status}")

    # Payoffs due
    payoffs_due = [t for t in planned_threads if t['payoff_episode'] == episode]
    if payoffs_due:
        print(f"\n{BOLD}MUST PAY OFF:{RESET}")
        for t in payoffs_due:
            status = f"{GREEN}✓ DONE{RESET}" if t['name'] in paid_off_names else f"{RED}PENDING{RESET}"
            print(f"  - {t['name']} ({t['type']}) {status}")

    if not plants_due and not payoffs_due:
        print(f"{CYAN}No threads scheduled for this episode{RESET}")

    # Active threads (planted but not paid off)
    active = [t for t in planned_threads
              if t['name'] in planted_names and t['name'] not in paid_off_names]

    if active:
        print(f"\n{BOLD}ACTIVE THREADS (can reference):{RESET}")
        for t in active:
            print(f"  - {t['name']} → payoff Ep {t['payoff_episode']}")

    print(f"\n{'═' * 60}\n")
    return True


def main():
    """Command-line entry point: parse ``sys.argv`` and dispatch a sub-command.

    ``--format FORMAT`` may appear anywhere on the command line. It is removed
    before positional arguments are read, so ``<project_path>``,
    ``<command>`` and the command's own arguments always come from the
    remaining list.

    Exits with status 1 on:
      * a usage error;
      * an unknown command;
      * a non-integer episode number.
    """
    # Strip "--format FORMAT" first. Reading positionals from sys.argv
    # afterwards would shift them whenever the flag precedes the command.
    format_name = None
    args = list(sys.argv)
    if '--format' in args:
        fmt_idx = args.index('--format')
        if fmt_idx + 1 < len(args):
            format_name = args[fmt_idx + 1]
            del args[fmt_idx:fmt_idx + 2]

    # Checked after stripping the flag: "script --format x" alone has no
    # project path or command.
    if len(args) < 3:
        print(f"Usage: python {sys.argv[0]} <project_path> <command> [args] [--format FORMAT]")
        print("\nCommands:")
        print("  status                    - Show thread status")
        print("  plant <name> <episode>    - Record a plant")
        print("  payoff <name> <episode>   - Record a payoff")
        print("  check <episode>           - Check episode requirements")
        print("\nOptions:")
        print("  --format FORMAT           - Use format-specific constants (e.g., kill_box, puzzle_box)")
        sys.exit(1)

    def parse_episode(raw):
        """Return *raw* as an int, or report the problem and exit with status 1."""
        try:
            return int(raw)
        except ValueError:
            print(f"{RED}Error: episode must be an integer, got '{raw}'{RESET}")
            sys.exit(1)

    project_path = Path(args[1])
    command = args[2]

    if command == 'status':
        show_status(project_path, format_name=format_name)
    elif command == 'plant':
        if len(args) < 5:
            print(f"Usage: python {sys.argv[0]} {project_path} plant <thread_name> <episode>")
            sys.exit(1)
        record_plant(project_path, args[3], parse_episode(args[4]))
    elif command == 'payoff':
        if len(args) < 5:
            print(f"Usage: python {sys.argv[0]} {project_path} payoff <thread_name> <episode>")
            sys.exit(1)
        record_payoff(project_path, args[3], parse_episode(args[4]))
    elif command == 'check':
        if len(args) < 4:
            print(f"Usage: python {sys.argv[0]} {project_path} check <episode>")
            sys.exit(1)
        check_episode(project_path, parse_episode(args[3]))
    else:
        print(f"{RED}Unknown command: {command}{RESET}")
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
