#!/usr/bin/python3
"""
Transition Gate - Full Series Transition Validation

HARD GATE that runs ONLY after batch 12 (episodes 56-60).
Validates ALL 59 transitions across the complete series.

Checks:
1. Pattern-based hard failures (time skips after Mid-Action, location changes)
2. Semantic flags that need AI review (character continuity, action resolution)

This gate BLOCKS series completion if hard failures are found.

Usage: python3 transition_gate.py <project_path> [--check] [--fix-list]

Options:
    --check     Return exit code based on validation result
    --fix-list  Output /rewrite commands for fixing issues

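Returns (exit codes 0/1/2 below are only produced with --check):
- Exit code 0: All transitions valid
- Exit code 1: Hard failures found (must fix); also used for bad usage or a
  missing project path, regardless of --check
- Exit code 2: Needs AI review (semantic issues)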
"""

import json
import sys
import re
from pathlib import Path
from datetime import datetime


def extract_section(content, section_name):
    """Extract the body of one timestamped section from episode content.

    A section starts at a heading of the form ``# [MM:SS - MM:SS] <name>``
    and runs until the next ``# [`` heading, a ``---`` rule, or the end of
    the text. The name is matched case-insensitively and literally: any
    regex metacharacters in ``section_name`` are escaped.

    Args:
        content: Full text of the episode file.
        section_name: Heading text that follows the timestamp range.

    Returns:
        The stripped section body, or ``""`` when no such section exists.
    """
    # Escape the name so headings containing '(', ')', '?', '+', ... are
    # matched as plain text instead of being interpreted as regex syntax.
    name = re.escape(section_name)
    pattern = rf'#\s*\[[\d:]+ - [\d:]+\]\s*{name}\s*\n(.*?)(?=#\s*\[|---|\Z)'
    match = re.search(pattern, content, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else ""


def extract_location(section_text):
    """Return the upper-cased location named on the first INT./EXT. slug line.

    The location is the text after ``INT.`` or ``EXT.`` up to the first
    hyphen (with any whitespace before it) or the end of that line.
    Returns ``None`` when the text contains no such slug line.
    """
    slug = re.search(
        r'(?:INT\.|EXT\.)\s*(.+?)(?:\s*-|$)',
        section_text,
        flags=re.MULTILINE,
    )
    if slug is None:
        return None
    place = slug.group(1)
    return place.strip().upper()


def load_episode(filepath):
    """Load one episode file and extract the fields used by the gate.

    Args:
        filepath: Path to an episode markdown file. The episode number is
            taken from an ``ep_<digits>.md`` fragment in the path.

    Returns:
        A dict with keys ``number`` (0 if the path has no episode number),
        ``filepath``, ``cliffhanger_type`` (single-letter code, ``'?'`` if
        not found), ``hook_type`` (``'?'`` if not found), ``hook``,
        ``cliffhanger``, ``hook_location`` and ``cliffhanger_location``.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    match = re.search(r'ep_(\d+)\.md', str(filepath))
    ep_num = int(match.group(1)) if match else 0

    # Cliffhanger metadata looks like "**CLIFFHANGER TYPE:** <label> (<code>)".
    # The label may be hyphenated or multi-word (e.g. "Mid-Action"), so accept
    # any run of characters on one line up to the parenthesised one-letter
    # code rather than a single \w+ word.
    cliff_match = re.search(
        r'\*\*CLIFFHANGER TYPE:\*\*\s*([^()\n]+?)\s*\((\w)\)', content)
    cliff_type = cliff_match.group(2) if cliff_match else '?'

    # Hook type is the rest of the metadata line (or up to the next "**").
    hook_match = re.search(r'\*\*HOOK TYPE:\*\*\s*(.+?)(?:\n|\*\*)', content)
    hook_type = hook_match.group(1).strip() if hook_match else '?'

    hook_section = extract_section(content, 'THE HOOK')
    cliffhanger_section = extract_section(content, 'THE CLIFFHANGER')

    return {
        'number': ep_num,
        'filepath': str(filepath),
        'cliffhanger_type': cliff_type,
        'hook_type': hook_type,
        'hook': hook_section,
        'cliffhanger': cliffhanger_section,
        'hook_location': extract_location(hook_section),
        'cliffhanger_location': extract_location(cliffhanger_section),
    }


def check_pattern_violations(prev_ep, curr_ep):
    """Check one transition for hard-fail pattern violations.

    Rules apply only when the previous episode ended on a Mid-Action
    cliffhanger (``cliffhanger_type == 'M'``):

    1. The first 150 characters of the next hook must not contain a
       time-skip phrase. Phrases are matched as whole words, so words such
       as "collateral" are not mistaken for the time skip "later". At most
       one time-skip violation is reported per transition.
    2. When both scenes have a parsed location, the cliffhanger location
       must equal, or be contained in, the hook location.

    Args:
        prev_ep: Episode dict providing ``number``, ``cliffhanger_type`` and
            ``cliffhanger_location``.
        curr_ep: Episode dict providing ``number``, ``hook`` and
            ``hook_location``.

    Returns:
        A list of violation dicts with keys ``severity`` (always ``'HARD'``),
        ``type``, ``transition``, ``issue`` and ``fix``. Empty when the
        transition is clean or the cliffhanger is not Mid-Action.
    """
    # Only Mid-Action cliffhangers demand an immediate, in-place pickup.
    if prev_ep['cliffhanger_type'] != 'M':
        return []

    violations = []
    transition = f"{prev_ep['number']}→{curr_ep['number']}"

    # Specific phrases come first so the report names the most precise
    # match; the generic 'later' is the fallback.
    time_skip_phrases = [
        'hours later', 'days later', 'weeks later', 'the next morning',
        'when she woke', 'when he woke', 'morning came',
        'twenty-six days', 'twenty-seven days',
        'later',
    ]

    # Rule 1: Mid-Action must have immediate pickup.
    window = 150
    # Keep one extra character so a word cut exactly at the window edge
    # (e.g. "later|al") is not treated as a complete word.
    hook_start = curr_ep['hook'].lower()[:window + 1]
    for phrase in time_skip_phrases:
        pattern = rf'\b{re.escape(phrase)}\b'
        if any(m.end() <= window for m in re.finditer(pattern, hook_start)):
            violations.append({
                'severity': 'HARD',
                'type': 'time_skip_after_midaction',
                'transition': transition,
                'issue': f"Time skip ('{phrase}') after Mid-Action cliffhanger",
                'fix': "Rewrite hook to continue immediately from the cliffhanger action"
            })
            break

    # Rule 2: Location should match for Mid-Action. Skipped when either
    # side has no parsed location.
    prev_loc = prev_ep['cliffhanger_location']
    curr_loc = curr_ep['hook_location']
    # Containment also covers equality, and lets a more specific hook slug
    # (e.g. "LAB CORRIDOR" after "LAB") pass.
    if prev_loc and curr_loc and prev_loc not in curr_loc:
        violations.append({
            'severity': 'HARD',
            'type': 'location_change_after_midaction',
            'transition': transition,
            'issue': f"Location change ({prev_loc} → {curr_loc}) after Mid-Action",
            'fix': "Continue in same location or show the transition"
        })

    return violations


def check_semantic_issues(prev_ep, curr_ep):
    """Collect heuristic findings for one transition that need AI review.

    Three keyword-based checks are run:

    * Character continuity (Mid-Action only): an all-caps name that appears
      at least twice in the cliffhanger but never in the hook.
    * Action resolution (Mid-Action only): the cliffhanger contains an
      action verb but the hook contains no resolution keyword.
    * Separation handling (any cliffhanger type): the cliffhanger contains a
      separation keyword but the hook contains no acknowledging keyword.

    Keyword lists are matched as substrings of the lower-cased text, so
    inflected forms ("finds", "reunited") also match.

    Args:
        prev_ep: Episode dict providing ``number``, ``cliffhanger_type`` and
            ``cliffhanger``.
        curr_ep: Episode dict providing ``number`` and ``hook``.

    Returns:
        A list of dicts with keys ``severity`` (always ``'REVIEW'``),
        ``type``, ``transition``, ``issue`` and ``question``. Character
        findings are listed in alphabetical order of the name.
    """
    issues = []

    prev_type = prev_ep['cliffhanger_type']
    prev_cliff = prev_ep['cliffhanger'].lower()
    curr_hook = curr_ep['hook'].lower()
    transition = f"{prev_ep['number']}→{curr_ep['number']}"

    # Character continuity check: all-caps words of 2+ letters are treated
    # as candidate character names.
    char_pattern = r'\b([A-Z]{2,})\b'
    cliff_tokens = re.findall(char_pattern, prev_ep['cliffhanger'])
    hook_tokens = re.findall(char_pattern, curr_ep['hook'])

    # Filter noise: screenplay keywords and other capitalised non-names.
    noise = {'INT', 'EXT', 'THE', 'AND', 'BUT', 'FOR', 'NOT', 'WITH', 'HIS', 'HER',
             'SHE', 'HE', 'THEY', 'RED', 'BLUE', 'WHITE', 'OVERLAY', 'POV', 'STATUS',
             'WARNING', 'ERROR', 'ALERT', 'SYSTEM', 'CONTINUOUS', 'LATER', 'MORNING',
             'AR', 'V.O', 'O.S', 'CONT', 'CUT', 'FADE', 'SLAM', 'CRASH', 'BOOM'}
    cliff_chars = set(cliff_tokens) - noise
    hook_chars = set(hook_tokens) - noise

    # For Mid-Action, main characters should appear in both
    if prev_type == 'M' and cliff_chars:
        # Sorted so the report order is stable across runs (set iteration
        # order for strings is not).
        for char in sorted(cliff_chars - hook_chars):
            # Count whole tokens, not substrings, so a short name such as
            # "AL" is not inflated by occurrences of "ALICE".
            if cliff_tokens.count(char) >= 2:
                issues.append({
                    'severity': 'REVIEW',
                    'type': 'character_continuity',
                    'transition': transition,
                    'issue': f"Character '{char}' in cliffhanger but not in hook",
                    'question': f"Is {char}'s absence intentional after Mid-Action?"
                })

    # Action resolution check for Mid-Action
    if prev_type == 'M':
        action_indicators = ['attacks', 'fires', 'runs', 'falls', 'reaches', 'grabs',
                             'moves', 'charges', 'strikes', 'lunges', 'dives']
        cliff_actions = [a for a in action_indicators if a in prev_cliff]

        if cliff_actions:
            resolution_indicators = ['hits', 'misses', 'catches', 'lands', 'dodges',
                                     'blocks', 'impact', 'crashes', 'slams', 'connects',
                                     'wide', 'barely', 'almost']
            hook_has_resolution = any(r in curr_hook for r in resolution_indicators)

            if not hook_has_resolution:
                issues.append({
                    'severity': 'REVIEW',
                    'type': 'action_resolution',
                    'transition': transition,
                    'issue': f"Action ({', '.join(cliff_actions)}) may not be resolved in hook",
                    'question': "Does the hook show the immediate result of the action?"
                })

    # Separation handling
    separation_words = ['separated', 'apart', 'alone', 'gone', 'dragged away', 'lost']
    if any(w in prev_cliff for w in separation_words):
        # 'alone' / 'without' count as acknowledging the separation, not
        # only words that signal a reunion.
        reunion_words = ['find', 'found', 'together', 'reunite', 'voice', 'here', 'alone', 'without']
        if not any(w in curr_hook for w in reunion_words):
            issues.append({
                'severity': 'REVIEW',
                'type': 'separation_handling',
                'transition': transition,
                'issue': "Separation not addressed in hook",
                'question': "Does the hook acknowledge the separation?"
            })

    return issues


def run_full_validation(project_path):
    """Run transition validation across all 60 episodes of a project.

    Episodes are read from ``<project_path>/episodes/ep_001.md`` through
    ``ep_060.md``. Every consecutive pair is checked for hard pattern
    violations and for semantic issues that need review.

    Args:
        project_path: Project root directory (str or Path).

    Returns:
        A dict with keys ``passed`` (True when there are no hard failures),
        ``hard_fails``, ``reviews``, ``transitions`` and ``summary``
        (``total_transitions``, ``hard_fail_count``, ``review_count``).
        When fewer than 60 episode files exist, ``passed`` is False, the
        lists are empty, the summary counts are zero and an ``error``
        message is included.
    """
    episodes_dir = Path(project_path) / 'episodes'

    # Load all episodes that exist on disk.
    episodes = []
    for ep_num in range(1, 61):
        filepath = episodes_dir / f'ep_{ep_num:03d}.md'
        if filepath.exists():
            episodes.append(load_episode(filepath))

    if len(episodes) < 60:
        # Keep the same shape as a normal result (including 'summary') so
        # report printing and JSON export work for the error case too.
        return {
            'passed': False,
            'error': f"Only {len(episodes)} episodes found. Need all 60.",
            'hard_fails': [],
            'reviews': [],
            'transitions': [],
            'summary': {
                'total_transitions': 0,
                'hard_fail_count': 0,
                'review_count': 0
            }
        }

    episodes = sorted(episodes, key=lambda x: x['number'])

    hard_fails = []
    reviews = []
    all_transitions = []

    # Walk consecutive (previous, current) episode pairs.
    for prev_ep, curr_ep in zip(episodes, episodes[1:]):
        transition = {
            'from': prev_ep['number'],
            'to': curr_ep['number'],
            'cliffhanger_type': prev_ep['cliffhanger_type'],
            'hook_type': curr_ep['hook_type'],
            'issues': []
        }

        # Check pattern violations (hard fails)
        pattern_violations = check_pattern_violations(prev_ep, curr_ep)
        hard_fails.extend(pattern_violations)
        transition['issues'].extend(pattern_violations)

        # Check semantic issues (need review)
        semantic_issues = check_semantic_issues(prev_ep, curr_ep)
        reviews.extend(semantic_issues)
        transition['issues'].extend(semantic_issues)

        all_transitions.append(transition)

    return {
        'passed': not hard_fails,
        'hard_fails': hard_fails,
        'reviews': reviews,
        'transitions': all_transitions,
        'summary': {
            'total_transitions': len(all_transitions),
            'hard_fail_count': len(hard_fails),
            'review_count': len(reviews)
        }
    }


def print_report(result, project_path):
    """Print a human-readable validation report to stdout.

    Prints a header with the counts, then either the error message (if the
    result carries one) or the hard failures, up to ten review items, and a
    final verdict banner.

    Args:
        result: Validation result dict. ``summary`` is optional; missing
            counts are derived from the ``transitions``, ``hard_fails`` and
            ``reviews`` lists so an error result without a summary can
            still be reported.
        project_path: Project root; its last path component is used as the
            project name in the header and in ``/rewrite`` commands.
    """
    project_name = Path(project_path).name

    # An error result may not carry a summary; never index it directly.
    summary = result.get('summary') or {}
    total_transitions = summary.get(
        'total_transitions', len(result.get('transitions', [])))
    hard_fail_count = summary.get(
        'hard_fail_count', len(result.get('hard_fails', [])))
    review_count = summary.get(
        'review_count', len(result.get('reviews', [])))

    print(f"\n{'='*70}")
    print("TRANSITION GATE - Full Series Validation")
    print(f"{'='*70}")
    print(f"Project: {project_name}")
    print(f"Transitions checked: {total_transitions}")
    print(f"Hard failures: {hard_fail_count}")
    print(f"Reviews needed: {review_count}")
    print(f"{'='*70}")

    if result.get('error'):
        print(f"\n[ERROR] {result['error']}")
        return

    if result['hard_fails']:
        print("\n[HARD FAILURES - MUST FIX]")
        print("-" * 50)
        for issue in result['hard_fails']:
            print(f"\nEp {issue['transition']}: {issue['issue']}")
            print(f"  Fix: {issue['fix']}")
            # The hook to rewrite belongs to the episode after the arrow.
            ep_num = issue['transition'].split('→')[1]
            print(f"  Command: /rewrite {project_name} ep {ep_num} \"{issue['issue']}\"")

    if result['reviews']:
        print("\n[NEEDS AI REVIEW]")
        print("-" * 50)
        for issue in result['reviews'][:10]:  # Show first 10
            print(f"\nEp {issue['transition']}: {issue['issue']}")
            print(f"  Review: {issue['question']}")
        if len(result['reviews']) > 10:
            print(f"\n  ... and {len(result['reviews']) - 10} more to review")

    print(f"\n{'='*70}")
    if result['passed'] and review_count == 0:
        print("[PASSED] All transitions validated successfully!")
        print("Series is ready for compilation.")
    elif result['passed']:
        print(f"[PASSED WITH REVIEWS] {review_count} transitions need verification")
        print("AI should verify these are intentional before marking complete.")
    else:
        print(f"[FAILED] {hard_fail_count} hard failures must be fixed")
        print("Series CANNOT be marked complete until these are resolved.")
    print(f"{'='*70}\n")


def print_fix_commands(result, project_path):
    """Emit a ``/rewrite`` command per hard failure and a comment per review.

    The project name is the last component of ``project_path``. Each line
    targets the episode on the right-hand side of the transition arrow.
    Nothing is printed for a category whose list is empty.
    """
    series = Path(project_path).name

    def target_episode(entry):
        # "12→13" -> "13": the hook that follows the cliffhanger.
        return entry['transition'].split('→')[1]

    failures = result['hard_fails']
    if failures:
        print("\n# Rewrite commands for hard failures:")
        for failure in failures:
            episode = target_episode(failure)
            reason = failure["issue"]
            print(f'/rewrite {series} ep {episode} "{reason}"')

    pending = result['reviews']
    if pending:
        print("\n# Review commands (verify these are intentional):")
        for item in pending:
            episode = target_episode(item)
            prompt = item['question']
            print(f"# Ep {episode}: {prompt}")


def main():
    """Command-line entry point for the transition gate.

    Reads the project path and the optional ``--check`` / ``--fix-list``
    flags from ``sys.argv``, runs the validation, prints either the report
    or the fix commands, and writes
    ``<project>/state/transition_gate_report.json``.

    Exit codes:
        1 for bad usage or a missing project path. With ``--check``:
        1 when validation did not pass, 2 when it passed but reviews are
        pending, 0 otherwise. Without ``--check`` the function returns
        normally after writing the report.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 transition_gate.py <project_path> [--check] [--fix-list]")
        print("\nOptions:")
        print("  --check     Return exit code based on validation")
        print("  --fix-list  Output rewrite commands for fixing")
        sys.exit(1)

    project_path = Path(sys.argv[1]).resolve()
    check_mode = '--check' in sys.argv
    fix_list = '--fix-list' in sys.argv

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}")
        sys.exit(1)

    result = run_full_validation(project_path)

    if fix_list:
        print_fix_commands(result, project_path)
    else:
        print_report(result, project_path)

    # An error result (e.g. missing episodes) may not carry a summary;
    # derive one from the lists instead of raising KeyError.
    summary = result.get('summary') or {
        'total_transitions': len(result.get('transitions', [])),
        'hard_fail_count': len(result['hard_fails']),
        'review_count': len(result['reviews'])
    }

    # Save report to state directory
    state_dir = project_path / 'state'
    state_dir.mkdir(parents=True, exist_ok=True)
    report_file = state_dir / 'transition_gate_report.json'

    report = {
        'timestamp': datetime.now().isoformat(),
        'passed': result['passed'],
        'summary': summary,
        'hard_fails': result['hard_fails'],
        'reviews': result['reviews']
    }
    if result.get('error'):
        # Record why validation could not run so the saved report is
        # self-explanatory.
        report['error'] = result['error']

    with open(report_file, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)

    if check_mode:
        if not result['passed']:
            sys.exit(1)
        elif summary['review_count'] > 0:
            sys.exit(2)
        else:
            sys.exit(0)


# Script entry point: run the gate only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
