#!/usr/bin/env python3
"""
Treatment Validation Tool
Validates treatment.md prose format for generation readiness.

Usage:
    python validate_treatment.py <project_path>
    python validate_treatment.py <project_path> --flag-weak
    python validate_treatment.py ./olympus

Requirements (Hard Gates):
- All episodes present (per CONSTANTS.md TOTAL_EPISODES)
- Each episode has metadata line (Sequence, Beat, Hook, Cliffhanger)
- Each episode has THE MOMENT line
- Episode 1 has a VOICE SEED line
- Key episodes (KEY_EPISODE_WORD_COUNTS) meet their minimum prose word count
- Total prose word count is within the treatment target range
  (per-episode word count by beat type is a soft flag, see below)
- Each episode has bracketed cliffhanger image
- Thread markers reference THREAD INDEX
- All threads have PLANT and PAYOFF
- Hook distribution 70-85% SILENT
- Cliffhanger distribution 70-85% MID-ACTION
- No 4+ consecutive same hook/cliffhanger SUBTYPE (max 3 allowed; see CONSTANTS.md);
  the cliffhanger check applies to episodes 11+ only (First 10 Episodes Rule)
  (Variety means subtype variety — a run of M-CF, M-PT, M-CT is fine;
   a run of 4x M-PT is not. Same main category with different subtypes = good variety.)

Soft Flags:
- Prose under/over word count for beat type
- Vague language ("tensions rise", etc.)
- Weak THE MOMENT (generic, not visual)
- Beat type 4+ consecutive
- Main-category consecutive runs (informational — not a violation if subtypes vary)
"""

import sys
import os
import re
from pathlib import Path
from collections import defaultdict

# Tunables come from the shared engine_constants module (reads from CONSTANTS.md).
# When that module cannot be imported, built-in defaults are used instead.
try:
    from engine_constants import (
        TOTAL_EPISODES,
        MAX_CONSECUTIVE_SAME_TYPE,
        KEY_EPISODE_WORD_COUNTS,
        TREATMENT_WORD_COUNT_RANGES,
        TREATMENT_WORD_COUNT_TOLERANCE,
        TREATMENT_TOTAL_WORDS_MIN,
        TREATMENT_TOTAL_WORDS_MAX,
        HOOK_SILENT_VALIDATION_MIN,
        HOOK_SILENT_VALIDATION_MAX,
        CLIFFHANGER_MIDACTION_VALIDATION_MIN,
        CLIFFHANGER_MIDACTION_VALIDATION_MAX,
    )
except ImportError:
    # Shared module unavailable: say so, then define every tunable locally
    print("WARNING: Could not import engine_constants, using fallback values")
    # Per-beat prose word-count ranges as (min, max)
    WORD_COUNT_RANGES = {
        beat: (low, high)
        for beat, low, high in (
            ('SETUP', 40, 55),
            ('COMPLICATION', 40, 55),
            ('CATALYST', 50, 70),
            ('LOCK-IN', 50, 70),
            ('COLLISION', 60, 80),
            ('CRISIS', 60, 80),
            ('REVELATION', 60, 80),
            ('CLIMAX', 70, 90),
            ('RESOLUTION', 50, 70),
        )
    }
    # Episodes with their own (min, max) word-count range, keyed by episode number
    KEY_EPISODE_WORD_COUNTS = {1: (80, 100), **dict.fromkeys((10, 15), (70, 90))}
    WORD_COUNT_TOLERANCE = 10
    TOTAL_WORD_COUNT_TARGET, TOTAL_WORD_COUNT_TOLERANCE = 3500, 500
    MAX_CONSECUTIVE_SAME_TYPE = 3
    HOOK_SILENT_VALIDATION_MIN, HOOK_SILENT_VALIDATION_MAX = 70, 85
    CLIFFHANGER_MIDACTION_VALIDATION_MIN, CLIFFHANGER_MIDACTION_VALIDATION_MAX = 70, 85
    TOTAL_EPISODES = 61
else:
    # Short aliases kept for backward compatibility
    WORD_COUNT_RANGES = TREATMENT_WORD_COUNT_RANGES
    WORD_COUNT_TOLERANCE = TREATMENT_WORD_COUNT_TOLERANCE
    # The [min, max] total-word window expressed as midpoint and half-width
    TOTAL_WORD_COUNT_TOLERANCE = (TREATMENT_TOTAL_WORDS_MAX - TREATMENT_TOTAL_WORDS_MIN) // 2
    TOTAL_WORD_COUNT_TARGET = (TREATMENT_TOTAL_WORDS_MIN + TREATMENT_TOTAL_WORDS_MAX) // 2

# ANSI escape sequences used to color terminal output
GREEN, RED, YELLOW, CYAN = ('\033[%dm' % code for code in (92, 91, 93, 96))
BOLD = '\033[1m'
RESET = '\033[0m'

# Beat types accepted in an episode's metadata line
VALID_BEAT_TYPES = [
    'SETUP',
    'CATALYST',
    'LOCK-IN',
    'COMPLICATION',
    'COLLISION',
    'CRISIS',
    'REVELATION',
    'CLIMAX',
    'RESOLUTION',
]

# Hook codes (subtypes from Appendix A) mapped to their main category.
# The bare main-category names are accepted too, for backward compatibility.
VALID_HOOK_SUBTYPES = {
    **dict.fromkeys(('S-VP', 'S-SF', 'S-UI', 'S-CO', 'S-DE', 'S-CT', 'S-PV'), 'SILENT'),
    **dict.fromkeys(('D-PI', 'D-QU', 'D-DC', 'D-MC'), 'DIALOGUE'),
    'SILENT': 'SILENT',
    'DIALOGUE': 'DIALOGUE',
}

# Cliffhanger codes mapped to their main category, same scheme as the hooks.
VALID_CLIFF_SUBTYPES = {
    **dict.fromkeys(('M-PT', 'M-CT', 'M-CH', 'M-PU', 'M-CF'), 'MID-ACTION'),
    **dict.fromkeys(('A-RE', 'A-CO', 'A-PS', 'A-SI', 'A-CT', 'A-DE'), 'AFTERMATH'),
    'MID-ACTION': 'MID-ACTION',
    'AFTERMATH': 'AFTERMATH',
}

# Words that signal danger; used for the Ep 10 urgency check
DANGER_WORDS = (
    'fail dying death collapse attack danger threat '
    'kill destroy crash fall hurt wound bleed shock '
    'alarm warning emergency explode fire trap chase '
    'hunt escape flee run panic scream'
).split()

# Vague phrases to flag (word boundaries prevent matching inside compounds like "de-escalation").
# Alternations use non-capturing groups on purpose: with capturing groups,
# VAGUE_PATTERN.findall() returns tuples of group text instead of the matched
# phrase, and matches on group-less alternatives come back as empty strings.
VAGUE_PHRASES = [
    r'\btensions?\s+rise\b',
    r'\bthings?\s+escalat\w*\b',
    r'\brelationship\s+deep\w*\b',
    r'\brealizations?\s+occur\b',
    r'\bstakes?\s+increas\w*\b',
    r'\bsomething\s+chang\w*\b',
    r'\bthings?\s+get\s+(?:complicated|worse|better)\b',
    r'\bsituation\s+(?:develops|evolves|changes)\b',
    r'\bconflict\s+(?:grows|develops|intensifies)\b',
]
# Single case-insensitive pattern matching any of the phrases above
VAGUE_PATTERN = re.compile('|'.join(VAGUE_PHRASES), re.IGNORECASE)


def parse_treatment(filepath):
    """Parse treatment.md and extract the thread index and per-episode data.

    Args:
        filepath: Path to the treatment markdown file.

    Returns:
        A ``(data, error)`` tuple. On success ``error`` is ``None`` and
        ``data`` is a dict with two keys: ``'episodes'`` (a list of
        per-episode dicts in file order) and ``'threads_in_index'`` (a set of
        lower-cased thread names read from the THREAD INDEX table). If the
        file does not exist, ``data`` is ``None`` and ``error`` is a message.

    Problems found while parsing (unknown thread, invalid beat type, vague
    language) are recorded as strings in each episode's ``'issues'`` list.
    """
    if not os.path.exists(filepath):
        return None, f"File not found: {filepath}"

    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    episodes = []
    threads_in_index = set()

    # Locate the THREAD INDEX section; it ends at the next level-2 heading
    # or at end of file.
    thread_idx_start = content.find('## THREAD INDEX')
    thread_idx_end = content.find('\n## ', thread_idx_start + 1) if thread_idx_start != -1 else -1

    if thread_idx_start != -1:
        thread_section = content[thread_idx_start:thread_idx_end] if thread_idx_end != -1 else content[thread_idx_start:]

        # Parse markdown table rows - the first column holds the thread name
        for line in thread_section.split('\n'):
            line = line.strip()
            # Skip header row, separator row, and empty lines
            if not line or line.startswith('|--') or '| Thread |' in line:
                continue
            # Table row should start with |
            if line.startswith('|'):
                cells = [c.strip() for c in line.split('|')]
                # cells[0] is empty (before first |), cells[1] is the Thread name
                if len(cells) > 1 and cells[1]:
                    thread_name = cells[1].strip()
                    # Skip if it looks like a header
                    if thread_name.lower() not in ['thread', '']:
                        threads_in_index.add(thread_name.lower())

    # Episode header pattern: ### Episode N: "Title"
    episode_header_pattern = re.compile(
        r'###\s*Episode\s*(\d+):\s*"([^"]+)"',
        re.IGNORECASE
    )

    # Metadata line pattern (pipe-separated)
    # Accepts both main categories (SILENT, MID-ACTION) and subtypes (S-VP, M-PT).
    # Longest codes are listed first in each alternation.
    hook_codes = '|'.join(sorted(VALID_HOOK_SUBTYPES.keys(), key=len, reverse=True))
    cliff_codes = '|'.join(sorted(VALID_CLIFF_SUBTYPES.keys(), key=len, reverse=True))
    metadata_pattern = re.compile(
        r'\*\*Sequence:\*\*\s*(\d+)\s*\|\s*\*\*Beat:\*\*\s*(\w+[-\w]*)\s*\|\s*\*\*Hook:\*\*\s*(' + hook_codes + r')\s*\|\s*\*\*Cliffhanger:\*\*\s*(' + cliff_codes + r')',
        re.IGNORECASE
    )

    # Threads line pattern
    threads_line_pattern = re.compile(
        r'\*\*Threads:\*\*\s*(.+?)(?:\n|$)',
        re.IGNORECASE
    )

    # Individual thread marker pattern: [PLANT: name], [ADVANCE: name], [PAYOFF: name]
    thread_marker_pattern = re.compile(
        r'\[(PLANT|ADVANCE|PAYOFF):\s*([^\]]+)\]',
        re.IGNORECASE
    )

    # THE MOMENT pattern
    moment_pattern = re.compile(
        r'\*\*THE MOMENT:\*\*\s*(.+?)(?:\n|$)',
        re.IGNORECASE
    )

    # VOICE SEED pattern (Episode 1 only)
    voice_seed_pattern = re.compile(
        r'\*\*VOICE SEED:\*\*\s*(.+?)(?:\n|$)',
        re.IGNORECASE
    )

    # Cliffhanger image pattern
    cliffhanger_image_pattern = re.compile(
        r'\*\*\[CLIFFHANGER:\s*(.+?)\]\*\*',
        re.IGNORECASE
    )

    # Split content into episode blocks; the lookahead keeps each header with its block
    episode_blocks = re.split(r'(?=###\s*Episode\s*\d+)', content)

    for block in episode_blocks:
        header_match = episode_header_pattern.search(block)
        if not header_match:
            # Text before the first episode, or a header without a quoted title
            continue

        ep_num = int(header_match.group(1))
        title = header_match.group(2)

        episode_data = {
            'episode': ep_num,
            'title': title,
            'sequence': None,
            'beat_type': None,
            'hook_type': None,          # Main category: SILENT or DIALOGUE
            'hook_subtype': None,       # Subtype code: S-VP, D-PI, etc. (or main category if none)
            'cliffhanger_type': None,   # Main category: MID-ACTION or AFTERMATH
            'cliff_subtype': None,      # Subtype code: M-PT, A-RE, etc. (or main category if none)
            'threads': [],
            'the_moment': None,
            'voice_seed': None,
            'cliffhanger_image': None,
            'prose': None,
            'prose_word_count': 0,
            'has_metadata': False,
            'issues': []
        }

        # Extract metadata line
        meta_match = metadata_pattern.search(block)
        if meta_match:
            episode_data['sequence'] = int(meta_match.group(1))
            episode_data['beat_type'] = meta_match.group(2).upper()
            raw_hook = meta_match.group(3).upper()
            raw_cliff = meta_match.group(4).upper()

            # Resolve hook: subtype -> main category.
            # The membership checks are defensive; the metadata regex is built
            # from the same code tables.
            episode_data['hook_subtype'] = raw_hook
            episode_data['hook_type'] = VALID_HOOK_SUBTYPES.get(raw_hook, raw_hook)
            if raw_hook not in VALID_HOOK_SUBTYPES:
                episode_data['issues'].append(f"Invalid hook code: {raw_hook}")

            # Resolve cliffhanger: subtype -> main category
            episode_data['cliff_subtype'] = raw_cliff
            episode_data['cliffhanger_type'] = VALID_CLIFF_SUBTYPES.get(raw_cliff, raw_cliff)
            if raw_cliff not in VALID_CLIFF_SUBTYPES:
                episode_data['issues'].append(f"Invalid cliffhanger code: {raw_cliff}")

            episode_data['has_metadata'] = True

            if episode_data['beat_type'] not in VALID_BEAT_TYPES:
                episode_data['issues'].append(f"Invalid beat type: {episode_data['beat_type']}")

        # Extract threads (only the first **Threads:** line of the block is read)
        threads_match = threads_line_pattern.search(block)
        if threads_match:
            threads_text = threads_match.group(1)
            for tm in thread_marker_pattern.finditer(threads_text):
                marker_type = tm.group(1).upper()
                thread_name = tm.group(2).strip()
                episode_data['threads'].append({
                    'type': marker_type,
                    'name': thread_name
                })
                # Check if thread is in index (case-insensitive)
                if thread_name.lower() not in threads_in_index:
                    episode_data['issues'].append(f"Thread '{thread_name}' not in THREAD INDEX")

        # Extract THE MOMENT
        moment_match = moment_pattern.search(block)
        if moment_match:
            episode_data['the_moment'] = moment_match.group(1).strip()

        # Extract VOICE SEED (Episode 1 only)
        voice_seed_match = voice_seed_pattern.search(block)
        if voice_seed_match:
            episode_data['voice_seed'] = voice_seed_match.group(1).strip()

        # Extract cliffhanger image
        cliff_match = cliffhanger_image_pattern.search(block)
        if cliff_match:
            episode_data['cliffhanger_image'] = cliff_match.group(1).strip()

        # Extract prose paragraph: the non-blank lines after the metadata-style
        # lines and before the cliffhanger line.
        prose_lines = []
        in_prose = False

        for line in block.split('\n'):
            # Skip header, metadata, threads, THE MOMENT, VOICE SEED lines
            if episode_header_pattern.match(line):
                continue
            if '**Sequence:**' in line or '**Threads:**' in line or '**THE MOMENT:**' in line or '**VOICE SEED:**' in line:
                in_prose = True
                continue
            if '**[CLIFFHANGER:' in line:
                in_prose = False
                continue
            # Horizontal rules (---) are separators, not prose
            if in_prose and line.strip() and not line.startswith('---'):
                prose_lines.append(line.strip())

        if prose_lines:
            episode_data['prose'] = ' '.join(prose_lines)
            episode_data['prose_word_count'] = len(episode_data['prose'].split())

            # Check for vague language. Use the full matched text (group 0):
            # findall() would return per-group tuples, which are empty for
            # alternatives that have no capturing group, so those matches
            # would go unreported.
            vague_hits = [m.group(0) for m in VAGUE_PATTERN.finditer(episode_data['prose'])]
            if vague_hits:
                episode_data['issues'].append(f"Vague language: {', '.join(vague_hits)}")

        episodes.append(episode_data)

    return {
        'episodes': episodes,
        'threads_in_index': threads_in_index
    }, None


def validate_treatment(project_path):
    """Run treatment format validation for the project at ``project_path``.

    Looks for ``treatment.md`` directly inside the project directory, parses
    it with ``parse_treatment`` and evaluates the hard gates and soft flags.

    Args:
        project_path: Project directory (string or path-like).

    Returns:
        A results dict with keys:
          - ``'passed'``: True only if every hard gate passed.
          - ``'errors'``: list of fatal error messages (file missing or
            unreadable); when non-empty, the gate dicts are empty.
          - ``'episodes'``: the parsed per-episode dicts.
          - ``'hard_gates'`` / ``'soft_flags'``: per-check result dicts.
          - ``'treatment_path'``: path of the validated file (only present
            when parsing succeeded).

    Run-length violations (``pattern_variety`` and ``main_category_runs``)
    report ``start``/``end`` as episode numbers, for hooks and cliffhangers
    alike, and use MAX_CONSECUTIVE_SAME_TYPE as the longest allowed run.
    """
    project_path = Path(project_path)

    # Find treatment.md
    treatment_path = project_path / 'treatment.md'
    if not treatment_path.exists():
        return {
            'passed': False,
            'errors': [f'treatment.md not found at {treatment_path}'],
            'episodes': [],
            'hard_gates': {},
            'soft_flags': {}
        }

    data, err = parse_treatment(treatment_path)
    if err:
        return {
            'passed': False,
            'errors': [err],
            'episodes': [],
            'hard_gates': {},
            'soft_flags': {}
        }

    episodes = data['episodes']

    results = {
        'passed': True,
        'errors': [],
        'episodes': episodes,
        'hard_gates': {},
        'soft_flags': {},
        'treatment_path': str(treatment_path)
    }

    # === HARD GATES ===

    # 1. Coverage - all episodes present (per CONSTANTS.md TOTAL_EPISODES)
    found_eps = {ep['episode'] for ep in episodes}
    missing_eps = [i for i in range(1, TOTAL_EPISODES + 1) if i not in found_eps]
    results['hard_gates']['coverage'] = {
        'found': len(episodes),
        'required': TOTAL_EPISODES,
        'passed': len(missing_eps) == 0,
        'missing': missing_eps
    }

    # 2. Metadata - all episodes have metadata line
    eps_with_metadata = [ep for ep in episodes if ep['has_metadata']]
    results['hard_gates']['metadata'] = {
        'count': len(eps_with_metadata),
        'required': TOTAL_EPISODES,
        'passed': len(eps_with_metadata) >= TOTAL_EPISODES,
        'missing': [ep['episode'] for ep in episodes if not ep['has_metadata']]
    }

    # 3. THE MOMENT - all episodes have it
    eps_with_moment = [ep for ep in episodes if ep['the_moment']]
    results['hard_gates']['the_moment'] = {
        'count': len(eps_with_moment),
        'required': TOTAL_EPISODES,
        'passed': len(eps_with_moment) >= TOTAL_EPISODES,
        'missing': [ep['episode'] for ep in episodes if not ep['the_moment']]
    }

    # 4. VOICE SEED - Episode 1 must have it
    ep1 = next((ep for ep in episodes if ep['episode'] == 1), None)
    voice_seed_passed = ep1 is not None and ep1.get('voice_seed') is not None
    results['hard_gates']['voice_seed'] = {
        'passed': voice_seed_passed,
        'episode_1_has_seed': voice_seed_passed,
        'seed': ep1.get('voice_seed') if ep1 else None
    }

    # 5. Key episode word counts - key episodes must meet their minimum
    #    (less the tolerance). Only the minimum is enforced; a missing key
    #    episode is left to the coverage gate.
    key_ep_issues = []
    for ep_num, (min_wc, _max_wc) in KEY_EPISODE_WORD_COUNTS.items():
        ep = next((e for e in episodes if e['episode'] == ep_num), None)
        if ep and ep['prose_word_count'] < min_wc - WORD_COUNT_TOLERANCE:
            key_ep_issues.append({
                'episode': ep_num,
                'count': ep['prose_word_count'],
                'minimum': min_wc,
                'issue': f'under minimum ({min_wc})'
            })
    results['hard_gates']['key_episode_words'] = {
        'passed': len(key_ep_issues) == 0,
        'issues': key_ep_issues
    }

    # 6. Total word count - summed prose must fall within target +/- tolerance
    total_words = sum(ep['prose_word_count'] for ep in episodes)
    min_total = TOTAL_WORD_COUNT_TARGET - TOTAL_WORD_COUNT_TOLERANCE
    max_total = TOTAL_WORD_COUNT_TARGET + TOTAL_WORD_COUNT_TOLERANCE
    total_word_passed = min_total <= total_words <= max_total

    results['hard_gates']['total_word_count'] = {
        'count': total_words,
        'target': TOTAL_WORD_COUNT_TARGET,
        'range': f'{min_total}-{max_total}',
        'passed': total_word_passed,
        'read_time_min': total_words // 200
    }

    # SOFT FLAG: per-episode prose word counts within range for beat type.
    # Key episodes are skipped (checked by gate 5), as are episodes without
    # a beat type or without prose.
    word_count_issues = []
    for ep in episodes:
        if ep['episode'] in KEY_EPISODE_WORD_COUNTS:
            continue
        if not ep['beat_type'] or ep['prose_word_count'] <= 0:
            continue
        min_wc, max_wc = WORD_COUNT_RANGES.get(ep['beat_type'], (60, 80))
        if ep['prose_word_count'] < min_wc - WORD_COUNT_TOLERANCE:
            direction = 'under'
        elif ep['prose_word_count'] > max_wc + WORD_COUNT_TOLERANCE:
            direction = 'over'
        else:
            continue
        word_count_issues.append({
            'episode': ep['episode'],
            'count': ep['prose_word_count'],
            'beat': ep['beat_type'],
            'range': f"{min_wc}-{max_wc}",
            'issue': direction
        })

    # Word counts are soft flags - treatments are for human review
    # and prose length is flexible as long as content is captured
    results['soft_flags']['word_counts'] = {
        'count': len(word_count_issues),
        'issues': word_count_issues
    }

    # 7. Cliffhanger images - all episodes have them
    eps_with_cliff_image = [ep for ep in episodes if ep['cliffhanger_image']]
    results['hard_gates']['cliffhanger_images'] = {
        'count': len(eps_with_cliff_image),
        'required': TOTAL_EPISODES,
        'passed': len(eps_with_cliff_image) >= TOTAL_EPISODES,
        'missing': [ep['episode'] for ep in episodes if not ep['cliffhanger_image']]
    }

    # 8. Thread coherence - all markers reference index
    #    (the parser records these as per-episode issue strings)
    thread_issues = [
        {'episode': ep['episode'], 'issue': issue}
        for ep in episodes
        for issue in ep['issues']
        if 'not in THREAD INDEX' in issue
    ]
    results['hard_gates']['thread_coherence'] = {
        'passed': len(thread_issues) == 0,
        'issues': thread_issues
    }

    # 9. All threads have PLANT and PAYOFF
    thread_markers = defaultdict(lambda: {'plant': [], 'advance': [], 'payoff': []})
    for ep in episodes:
        for thread in ep['threads']:
            marker_type = thread['type'].lower()
            thread_name = thread['name'].lower()
            thread_markers[thread_name][marker_type].append(ep['episode'])

    unresolved_threads = []
    for thread_name, markers in thread_markers.items():
        if not markers['plant']:
            unresolved_threads.append({'thread': thread_name, 'issue': 'no PLANT'})
        if not markers['payoff']:
            unresolved_threads.append({'thread': thread_name, 'issue': 'no PAYOFF'})

    results['hard_gates']['thread_resolution'] = {
        'passed': len(unresolved_threads) == 0,
        'issues': unresolved_threads,
        'threads': dict(thread_markers)
    }

    # 10. Hook ratio - share of SILENT hooks within the configured range
    hooks = [ep['hook_type'] for ep in episodes if ep['hook_type']]
    silent_count = hooks.count('SILENT')
    dialogue_count = hooks.count('DIALOGUE')
    total_hooks = len(hooks)

    if total_hooks > 0:
        silent_pct = (silent_count / total_hooks) * 100
        hook_passed = HOOK_SILENT_VALIDATION_MIN <= silent_pct <= HOOK_SILENT_VALIDATION_MAX
    else:
        # No hooks parsed at all: the ratio cannot be met
        silent_pct = 0
        hook_passed = False

    results['hard_gates']['hook_ratio'] = {
        'silent_count': silent_count,
        'dialogue_count': dialogue_count,
        'silent_pct': round(silent_pct, 1),
        'target_range': f'{HOOK_SILENT_VALIDATION_MIN}-{HOOK_SILENT_VALIDATION_MAX}%',
        'passed': hook_passed
    }

    # 11. Cliffhanger ratio - share of MID-ACTION cliffhangers within range
    cliffs = [ep['cliffhanger_type'] for ep in episodes if ep['cliffhanger_type']]
    midaction_count = cliffs.count('MID-ACTION')
    aftermath_count = cliffs.count('AFTERMATH')
    total_cliffs = len(cliffs)

    if total_cliffs > 0:
        midaction_pct = (midaction_count / total_cliffs) * 100
        cliff_passed = CLIFFHANGER_MIDACTION_VALIDATION_MIN <= midaction_pct <= CLIFFHANGER_MIDACTION_VALIDATION_MAX
    else:
        midaction_pct = 0
        cliff_passed = False

    results['hard_gates']['cliffhanger_ratio'] = {
        'midaction_count': midaction_count,
        'aftermath_count': aftermath_count,
        'midaction_pct': round(midaction_pct, 1),
        'target_range': f'{CLIFFHANGER_MIDACTION_VALIDATION_MIN}-{CLIFFHANGER_MIDACTION_VALIDATION_MAX}%',
        'passed': cliff_passed
    }

    # 12. Pattern variety - SUBTYPE consecutive check (hard gate)
    # Variety means SUBTYPE variety: M-CF, M-PT, M-CT in a row = FINE (different kinds of mid-action)
    # A run longer than MAX_CONSECUTIVE_SAME_TYPE of one subtype = violation
    # NOTE: First 10 Episodes Rule - Ep 1-10 are exempt from cliffhanger pattern variety
    def check_consecutive(pairs, max_allowed=MAX_CONSECUTIVE_SAME_TYPE):
        """Find runs of equal labels longer than ``max_allowed``.

        ``pairs`` is a list of ``(episode_number, label)`` in episode order.
        Returns ``(ok, violations)``; each violation gives the label, the
        first and last episode number of the run, and the run length.
        """
        violations = []
        run_label = None
        run_eps = []  # episode numbers in the current run

        def close_run():
            if len(run_eps) > max_allowed:
                violations.append({
                    'type': run_label,
                    'start': run_eps[0],
                    'end': run_eps[-1],
                    'count': len(run_eps)
                })

        for ep_num, label in pairs:
            if run_eps and label == run_label:
                run_eps.append(ep_num)
                continue
            close_run()
            run_label = label
            run_eps = [ep_num]
        # Check the final run
        close_run()

        return len(violations) == 0, violations

    sorted_eps = sorted(episodes, key=lambda x: x['episode'])

    # HARD GATE: Subtype consecutive check.
    # This is the real variety check - different subtypes within a category = good variety
    hook_sub_pairs = [(ep['episode'], ep['hook_subtype']) for ep in sorted_eps
                      if ep['hook_subtype']]
    # For cliffhangers, only check episodes 11+ (First 10 Episodes Rule)
    cliff_sub_pairs = [(ep['episode'], ep['cliff_subtype']) for ep in sorted_eps
                       if ep['cliff_subtype'] and ep['episode'] > 10]

    hook_sub_ok, hook_sub_violations = check_consecutive(hook_sub_pairs)
    cliff_sub_ok, cliff_sub_violations = check_consecutive(cliff_sub_pairs)

    results['hard_gates']['pattern_variety'] = {
        'passed': hook_sub_ok and cliff_sub_ok,
        'hook_violations': hook_sub_violations,
        'cliffhanger_violations': cliff_sub_violations
    }

    # SOFT FLAG: Main-category consecutive runs (informational only)
    # A long MID-ACTION run with different subtypes is fine - just flag for awareness
    hook_main_pairs = [(ep['episode'], ep['hook_type']) for ep in sorted_eps
                       if ep['hook_type']]
    cliff_main_pairs = [(ep['episode'], ep['cliffhanger_type']) for ep in sorted_eps
                        if ep['cliffhanger_type'] and ep['episode'] > 10]

    _, hook_main_violations = check_consecutive(hook_main_pairs)
    _, cliff_main_violations = check_consecutive(cliff_main_pairs)

    results['soft_flags']['main_category_runs'] = {
        'hook_violations': hook_main_violations,
        'cliff_violations': cliff_main_violations,
        'count': len(hook_main_violations) + len(cliff_main_violations)
    }

    # Overall pass: every hard gate recorded above must have passed
    # (per-episode word counts are soft flags, not hard gates)
    results['passed'] = all(gate['passed'] for gate in results['hard_gates'].values())

    # === SOFT FLAGS ===

    # 1. Vague language (recorded by the parser as per-episode issue strings)
    vague_eps = [ep for ep in episodes if any('Vague language' in issue for issue in ep['issues'])]
    results['soft_flags']['vague_language'] = {
        'count': len(vague_eps),
        'episodes': [{'episode': ep['episode'], 'issues': [i for i in ep['issues'] if 'Vague' in i]} for ep in vague_eps]
    }

    # 2. Beat type patterns (4+ consecutive)
    beat_seq = [(ep['episode'], ep['beat_type']) for ep in sorted_eps if ep['beat_type']]
    consecutive_beats = []
    if beat_seq:
        current_type = beat_seq[0][1]
        current_run = [beat_seq[0][0]]
        for ep_num, beat_type in beat_seq[1:]:
            if beat_type == current_type:
                current_run.append(ep_num)
            else:
                if len(current_run) >= 4:
                    consecutive_beats.append({
                        'type': current_type,
                        'episodes': current_run.copy(),
                        'count': len(current_run)
                    })
                current_type = beat_type
                current_run = [ep_num]
        # Check the final run
        if len(current_run) >= 4:
            consecutive_beats.append({
                'type': current_type,
                'episodes': current_run,
                'count': len(current_run)
            })

    results['soft_flags']['beat_patterns'] = {
        'count': len(consecutive_beats),
        'sequences': consecutive_beats
    }

    # 3. Weak THE MOMENT (generic or too short)
    weak_moments = []
    weak_moment_patterns = ['develops', 'changes', 'happens', 'occurs', 'begins']
    for ep in episodes:
        if ep['the_moment']:
            moment_lower = ep['the_moment'].lower()
            if len(ep['the_moment']) < 20:
                weak_moments.append({'episode': ep['episode'], 'moment': ep['the_moment'], 'issue': 'too short'})
            elif any(p in moment_lower for p in weak_moment_patterns):
                weak_moments.append({'episode': ep['episode'], 'moment': ep['the_moment'], 'issue': 'generic language'})

    results['soft_flags']['weak_moments'] = {
        'count': len(weak_moments),
        'episodes': weak_moments
    }

    # 4. Episode 10 urgency - cliffhanger should have danger element
    #    (substring match against DANGER_WORDS)
    ep10 = next((ep for ep in episodes if ep['episode'] == 10), None)
    ep10_urgency_ok = False
    if ep10 and ep10['cliffhanger_image']:
        cliff_lower = ep10['cliffhanger_image'].lower()
        ep10_urgency_ok = any(word in cliff_lower for word in DANGER_WORDS)

    results['soft_flags']['ep10_urgency'] = {
        'passed': ep10_urgency_ok,
        'cliffhanger': ep10['cliffhanger_image'] if ep10 else None,
        'note': 'Ep 10 cliffhanger should have physical danger element (paywall conversion)'
    }

    # 5. Continuity check - flag cliffhanger -> hook transitions that might need review
    # MID-ACTION cliffhanger should typically resolve in SILENT hook (continuing the action)
    # AFTERMATH cliffhanger can transition to either hook type
    continuity_flags = []
    for prev_ep, next_ep in zip(sorted_eps, sorted_eps[1:]):
        # Flag potential discontinuity: MID-ACTION -> DIALOGUE might lose tension
        if prev_ep['cliffhanger_type'] == 'MID-ACTION' and next_ep['hook_type'] == 'DIALOGUE':
            continuity_flags.append({
                'prev_ep': prev_ep['episode'],
                'next_ep': next_ep['episode'],
                'prev_cliff': prev_ep['cliffhanger_type'],
                'next_hook': next_ep['hook_type'],
                'prev_cliff_image': prev_ep['cliffhanger_image'][:50] if prev_ep['cliffhanger_image'] else None,
                'note': 'MID-ACTION → DIALOGUE may lose tension (review recommended)'
            })

    results['soft_flags']['continuity'] = {
        'flags': continuity_flags,
        'count': len(continuity_flags),
        'note': 'MID-ACTION cliffhangers should typically resolve visually, not with dialogue'
    }

    return results


def print_report(results, project_name):
    """Print the formatted validation report and return the overall verdict.

    The report has three parts: a HARD GATES section (one entry per gate in
    ``results['hard_gates']``), a SOFT FLAGS section (advisory findings from
    ``results['soft_flags']``) and a final PASS/FAIL banner. On FAIL, a
    "FIX REQUIRED" list names a remedy for every failing hard gate that is
    shown in the report, including the two thread gates.

    Args:
        results: Result dict produced by the validator. This function reads
            the ``errors``, ``hard_gates``, ``soft_flags``, ``episodes`` and
            ``passed`` keys.
        project_name: Project name shown in the header and in the suggested
            ``/treatment`` follow-up command.

    Returns:
        ``False`` if ``results['errors']`` is non-empty (only the errors are
        printed in that case); otherwise ``results['passed']``.
    """

    def mark(passed):
        """Return a green check mark if passed, else a red cross."""
        return f"{GREEN}✓{RESET}" if passed else f"{RED}✗{RESET}"

    def clip(text, limit):
        """Cut text to limit characters, adding '...' only when it was cut."""
        return f"{text[:limit]}..." if len(text) > limit else text

    def print_count_gate(title, label, gate, count_key='count', lead='\n'):
        """Print a "<count>/<required>" gate plus up to 10 missing episodes."""
        print(f"{lead}{title}")
        print(f"  {label}: {gate[count_key]}/{gate['required']} {mark(gate['passed'])}")
        if gate['missing']:
            print(f"  {RED}Missing: Ep {', '.join(map(str, gate['missing'][:10]))}{RESET}")

    def run_line(kind, run):
        """Format one consecutive-run entry, e.g. 'Hook: 4x SILENT (Ep 3-6)'."""
        return f"{kind}: {run['count']}x {run['type']} (Ep {run['start']}-{run['end']})"

    heavy_rule = '═' * 65
    light_rule = '─' * 65

    print(f"\n{heavy_rule}")
    print(f"{BOLD}TREATMENT VALIDATION: {project_name.upper()}{RESET}")
    print(f"{heavy_rule}\n")

    # Structural errors short-circuit the report: nothing else is printed.
    if results['errors']:
        print(f"{RED}Errors:{RESET}")
        for err in results['errors']:
            print(f"  - {err}")
        return False

    gates = results['hard_gates']
    soft = results['soft_flags']

    print(f"{BOLD}HARD GATES{RESET}")
    print(f"{light_rule}\n")

    # Coverage / metadata / THE MOMENT share the count-and-missing layout.
    print_count_gate("COVERAGE", "Episodes found", gates['coverage'],
                     count_key='found', lead='')
    print_count_gate("METADATA LINES", "Complete", gates['metadata'])
    print_count_gate("THE MOMENT", "Present", gates['the_moment'])

    # VOICE SEED
    vs = gates['voice_seed']
    print("\nVOICE SEED (Episode 1)")
    print(f"  Episode 1 has seed: {mark(vs['passed'])}")
    if vs['seed']:
        print(f"  Seed: \"{clip(vs['seed'], 50)}\"")
    elif not vs['passed']:
        print(f"  {RED}Missing: Add **VOICE SEED:** line to Episode 1{RESET}")

    # KEY EPISODE WORDS
    kew = gates['key_episode_words']
    print("\nKEY EPISODE WORD COUNTS (Ep 1, 10, 15)")
    print(f"  All meet minimums: {mark(kew['passed'])}")
    for issue in kew['issues']:
        print(f"  {RED}Ep {issue['episode']}: {issue['count']} words ({issue['issue']}){RESET}")

    # TOTAL WORD COUNT
    twc = gates['total_word_count']
    print("\nTOTAL WORD COUNT")
    print(f"  Words: {twc['count']:,} (~{twc['read_time_min']} min read)")
    print(f"  Target: {twc['range']} words {mark(twc['passed'])}")

    # Cliffhanger images
    print_count_gate("CLIFFHANGER IMAGES", "Present", gates['cliffhanger_images'])

    # Thread coherence (first 5 issues only)
    tc = gates['thread_coherence']
    print("\nTHREAD COHERENCE")
    print(f"  All reference index: {mark(tc['passed'])}")
    for issue in tc['issues'][:5]:
        print(f"  {RED}Ep {issue['episode']}: {issue['issue']}{RESET}")

    # Thread resolution (first 5 issues only)
    tr = gates['thread_resolution']
    print("\nTHREAD RESOLUTION")
    print(f"  All have PLANT+PAYOFF: {mark(tr['passed'])}")
    for issue in tr['issues'][:5]:
        print(f"  {RED}'{issue['thread']}': {issue['issue']}{RESET}")

    # Hook ratio
    hr = gates['hook_ratio']
    print("\nHOOK DISTRIBUTION")
    print(f"  SILENT:   {hr['silent_count']} ({hr['silent_pct']}%)")
    print(f"  DIALOGUE: {hr['dialogue_count']}")
    print(f"  Target:   {hr['target_range']} SILENT {mark(hr['passed'])}")

    # Cliffhanger ratio
    cr = gates['cliffhanger_ratio']
    print("\nCLIFFHANGER DISTRIBUTION")
    print(f"  MID-ACTION: {cr['midaction_count']} ({cr['midaction_pct']}%)")
    print(f"  AFTERMATH:  {cr['aftermath_count']}")
    print(f"  Target:     {cr['target_range']} MID-ACTION {mark(cr['passed'])}")

    # Pattern variety (subtype-level — the real variety check)
    pv = gates['pattern_variety']
    print("\nPATTERN VARIETY (max 3 consecutive same subtype)")
    print(f"  All subtypes varied: {mark(pv['passed'])}")
    for run in pv['hook_violations']:
        print(f"  {RED}{run_line('Hook', run)}{RESET}")
    for run in pv['cliffhanger_violations']:
        print(f"  {RED}{run_line('Cliff', run)}{RESET}")

    # Subtype distribution (informational)
    hook_subtypes = defaultdict(int)
    cliff_subtypes = defaultdict(int)
    for ep in results['episodes']:
        if ep.get('hook_subtype'):
            hook_subtypes[ep['hook_subtype']] += 1
        if ep.get('cliff_subtype'):
            cliff_subtypes[ep['cliff_subtype']] += 1

    # Only show the breakdown when at least one label is finer-grained than
    # the main categories themselves.
    has_subtypes = (
        any(k not in ('SILENT', 'DIALOGUE') for k in hook_subtypes)
        or any(k not in ('MID-ACTION', 'AFTERMATH') for k in cliff_subtypes)
    )

    if has_subtypes:
        print("\nSUBTYPE DISTRIBUTION")
        for label, counts in (('Hooks', hook_subtypes), ('Cliffs', cliff_subtypes)):
            if counts:
                # Most frequent first; ties keep first-seen order (stable sort).
                ranked = sorted(counts.items(), key=lambda item: -item[1])
                print(f"  {label}: {', '.join(f'{k}:{v}' for k, v in ranked)}")

    # Soft flags
    print(f"\n{BOLD}SOFT FLAGS{RESET}")
    print(f"{light_rule}\n")

    # Word counts (soft flag - treatments are for human review)
    wc = soft['word_counts']
    if wc['count'] > 0:
        print(f"{YELLOW}PROSE WORD COUNTS: {wc['count']} episodes outside target range{RESET}")
        for issue in wc['issues'][:3]:
            print(f"  Ep {issue['episode']}: {issue['count']} words ({issue['beat']} target {issue['range']})")
        if len(wc['issues']) > 3:
            print(f"    ...and {len(wc['issues']) - 3} more")
    else:
        print(f"PROSE WORD COUNTS: All in range {GREEN}✓{RESET}")

    # Vague language
    vl = soft['vague_language']
    if vl['count'] > 0:
        print(f"\n{YELLOW}VAGUE LANGUAGE: {vl['count']} episodes{RESET}")
        for ep in vl['episodes'][:5]:
            print(f"  Ep {ep['episode']}: {ep['issues'][0]}")
    else:
        print(f"\nVAGUE LANGUAGE: None detected {GREEN}✓{RESET}")

    # Beat patterns
    bp = soft['beat_patterns']
    if bp['count'] > 0:
        print(f"\n{YELLOW}CONSECUTIVE BEAT PATTERNS (4+): {bp['count']} sequences{RESET}")
        for seq in bp['sequences']:
            ep_range = f"Ep {seq['episodes'][0]}-{seq['episodes'][-1]}"
            print(f"  {seq['count']}x {seq['type']} ({ep_range})")
    else:
        print(f"\nCONSECUTIVE BEAT PATTERNS: None detected {GREEN}✓{RESET}")

    # Weak moments
    wm = soft['weak_moments']
    if wm['count'] > 0:
        print(f"\n{YELLOW}WEAK THE MOMENTS: {wm['count']} episodes{RESET}")
        for ep in wm['episodes'][:5]:
            # Ellipsis only when the moment was actually cut at 40 chars.
            print(f"  Ep {ep['episode']}: \"{clip(ep['moment'], 40)}\" ({ep['issue']})")
    else:
        print(f"\nWEAK THE MOMENTS: None detected {GREEN}✓{RESET}")

    # Episode 10 urgency
    ep10u = soft['ep10_urgency']
    if not ep10u['passed']:
        print(f"\n{YELLOW}EPISODE 10 URGENCY: May lack physical danger element{RESET}")
        if ep10u['cliffhanger']:
            print(f"  Current: \"{clip(ep10u['cliffhanger'], 60)}\"")
        print(f"  {CYAN}Tip: Ep 10 is the paywall - cliffhanger should have urgency{RESET}")
    else:
        print(f"\nEPISODE 10 URGENCY: Has danger element {GREEN}✓{RESET}")

    # Continuity check
    cont = soft['continuity']
    if cont['count'] > 0:
        print(f"\n{YELLOW}CONTINUITY FLAGS: {cont['count']} transitions may need review{RESET}")
        for flag in cont['flags'][:5]:
            print(f"  Ep {flag['prev_ep']} → {flag['next_ep']}: {flag['prev_cliff']} → {flag['next_hook']}")
            if flag['prev_cliff_image']:
                # The image arrives already shortened by the validator, so the
                # original length is unknown here; the ellipsis is always shown.
                print(f"    Cliff: \"{flag['prev_cliff_image']}...\"")
        if cont['count'] > 5:
            print(f"    ...and {cont['count'] - 5} more")
        print(f"  {CYAN}Tip: MID-ACTION cliffs should typically resolve visually{RESET}")
    else:
        print(f"\nCONTINUITY FLAGS: No potential issues detected {GREEN}✓{RESET}")

    # Main-category runs (informational — not a violation if subtypes vary)
    mcr = soft.get('main_category_runs', {'count': 0, 'hook_violations': [], 'cliff_violations': []})
    if mcr['count'] > 0:
        print(f"\n{YELLOW}MAIN-CATEGORY RUNS (informational): {mcr['count']} long runs{RESET}")
        for run in mcr['hook_violations']:
            print(f"  {run_line('Hook', run)}")
        for run in mcr['cliff_violations']:
            print(f"  {run_line('Cliff', run)}")
        print(f"  {CYAN}Note: Not a violation if subtypes vary within the run{RESET}")
    elif has_subtypes:
        print(f"\nMAIN-CATEGORY RUNS: No long runs {GREEN}✓{RESET}")

    # Final result
    print(f"\n{heavy_rule}")
    if results['passed']:
        ep10_flag = 0 if ep10u['passed'] else 1
        soft_issues = (wc['count'] + vl['count'] + bp['count'] + wm['count']
                       + ep10_flag + cont['count'] + mcr['count'])
        if soft_issues > 0:
            print(f"{BOLD}{YELLOW}RESULT: PASS (with {soft_issues} warnings){RESET}")
            # Only vague-language and weak-moment findings map to specific
            # episodes that the follow-up command can target.
            weak_eps = ({ep['episode'] for ep in vl['episodes']}
                        | {ep['episode'] for ep in wm['episodes']})
            if weak_eps:
                ep_list = ','.join(map(str, sorted(weak_eps)[:10]))
                print(f"\n{CYAN}To address warnings:{RESET}")
                print(f"  /treatment {project_name} --plus {ep_list}")
        else:
            print(f"{BOLD}{GREEN}RESULT: PASS{RESET}")
    else:
        print(f"{BOLD}{RED}RESULT: FAIL{RESET}")
        print(f"\n{YELLOW}FIX REQUIRED:{RESET}")

        # NOTE(review): 5000 is hard-coded; it presumably mirrors the
        # configured minimum total word count — confirm against CONSTANTS.md.
        if twc['count'] < 5000:
            word_count_hint = f"Treatment too short ({twc['count']} words) - expand prose paragraphs"
        else:
            word_count_hint = f"Treatment too long ({twc['count']} words) - tighten prose paragraphs"

        # One remedy per hard gate, in the same order as the report above.
        # Ratio hints reuse the gate's own target_range so they always match
        # the thresholds that were actually applied.
        fix_hints = [
            ('coverage', "Add missing episodes to treatment.md"),
            ('metadata', "Add metadata lines (Sequence|Beat|Hook|Cliffhanger)"),
            ('the_moment', "Add THE MOMENT to each episode"),
            ('voice_seed', "Add VOICE SEED line to Episode 1"),
            ('key_episode_words', "Expand Ep 1 (125+ words), Ep 10/15 (100+ words)"),
            ('total_word_count', word_count_hint),
            ('cliffhanger_images', "Add [CLIFFHANGER: ...] image to each episode"),
            ('thread_coherence', "Make every thread marker reference a thread listed in THREAD INDEX"),
            ('thread_resolution', "Give every thread both a PLANT and a PAYOFF"),
            ('hook_ratio', f"Adjust hook distribution to {hr['target_range']} SILENT"),
            ('cliffhanger_ratio', f"Adjust cliffhanger distribution to {cr['target_range']} MID-ACTION"),
            ('pattern_variety',
             "Break up consecutive same-subtype patterns (e.g., 4x M-PT → vary with M-CF, M-CT)"),
        ]
        for gate_key, hint in fix_hints:
            if not gates[gate_key]['passed']:
                print(f"  - {hint}")
    print(f"{heavy_rule}\n")

    return results['passed']


def main():
    """Command-line entry point: validate one project and exit.

    Reads the project path from ``sys.argv[1]``, runs ``validate_treatment``
    on it and prints the report via ``print_report``.

    Exit status is 0 when the report passes, and 1 when the path argument is
    missing, when validation reports errors, or when the report fails.
    """
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <project_path>")
        print(f"       python {sys.argv[0]} <project_path> --flag-weak")
        print(f"Example: python {sys.argv[0]} ./olympus")
        sys.exit(1)

    # NOTE(review): --flag-weak is advertised in the usage text above but is
    # not read anywhere in this function — confirm whether it is handled
    # elsewhere or is a leftover.
    project_path = sys.argv[1]

    # Derive the display name from the absolute path so that "." and ".."
    # yield the real directory name instead of an empty string or "..".
    project_name = os.path.basename(os.path.abspath(project_path))

    results = validate_treatment(project_path)

    if results['errors']:
        print(f"{RED}Errors:{RESET}")
        for err in results['errors']:
            print(f"  - {err}")
        sys.exit(1)

    passed = print_report(results, project_name)
    sys.exit(0 if passed else 1)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
