#!/usr/bin/env python3
"""
Treatment Batch Validation Tool
Validates a batch of episodes in treatment.md during generation.

Usage:
    python validate_treatment_batch.py <project_path> <start_ep> <end_ep>
    python validate_treatment_batch.py ./leviathan 1 10
    python validate_treatment_batch.py ./leviathan 11 20

Purpose:
    Hard gate during batch-based treatment generation. Validates:
    1. Word count per batch (~900 words for 10 episodes)
    2. Running hook/cliffhanger ratios (warning only; tracks drift from 70-85% targets)
    3. Pattern violations (max 3 consecutive episodes with the same hook or cliffhanger subtype)
    4. Episode completeness (metadata, THE MOMENT, cliffhanger image)

Exit Codes:
    0 = PASS - proceed to next batch
    1 = FAIL - fix issues before continuing
"""

import sys
import os
import re
from pathlib import Path

# Tuning values come from the shared engine_constants module (which reads
# CONSTANTS.md). If that module cannot be imported, hard-coded fallbacks are
# used instead and a warning is printed so the substitution is visible.
try:
    from engine_constants import (
        TREATMENT_TOTAL_WORDS_MIN,
        TREATMENT_TOTAL_WORDS_MAX,
        TREATMENT_BATCH_SIZE,
        TOTAL_EPISODES,
        MAX_CONSECUTIVE_SAME_TYPE
    )
except ImportError:
    print("WARNING: Could not import engine_constants, using fallback values")
    BATCH_SIZE, TOTAL_EPISODES = 10, 60
    TARGET_TOTAL_WORDS = 3500
    MAX_CONSECUTIVE_SAME_TYPE = 3
else:
    BATCH_SIZE = TREATMENT_BATCH_SIZE
    # Aim for the midpoint of the allowed total-word window.
    TARGET_TOTAL_WORDS = (TREATMENT_TOTAL_WORDS_MIN + TREATMENT_TOTAL_WORDS_MAX) // 2

# Derived values (identical formula whichever source supplied the inputs):
# the total target spread evenly over the number of batches ...
WORDS_PER_BATCH_TARGET = TARGET_TOTAL_WORDS / (TOTAL_EPISODES / BATCH_SIZE)
# ... and the allowed deviation, in words, either side of that target.
BATCH_WORD_TOLERANCE = 150

# ANSI colors
# Escape sequences interpolated into the printed report; every colored span
# is terminated with RESET.
GREEN = '\033[92m'   # pass marks and success messages
RED = '\033[91m'     # fail marks and error messages
YELLOW = '\033[93m'  # drift warnings and "fix required" notices
CYAN = '\033[96m'    # "proceed to next batch" hint
BOLD = '\033[1m'     # section headings and the final result line
RESET = '\033[0m'    # ends the styling started by any code above

# Distribution targets (from CONSTANTS.md hook/cliffhanger ratio targets)
# Each tuple is an inclusive (low, high) percentage range. Cumulative ratios
# outside the range only add drift warnings; they do not fail the batch.
SILENT_HOOK_TARGET = (70, 85)     # 70-85% SILENT
MIDACTION_CLIFF_TARGET = (70, 85) # 70-85% MID-ACTION


# Hook codes accepted in an episode's metadata line, mapped to their main
# category. The bare category names are accepted as codes too.
_HOOK_CODE_CATEGORIES = {
    'S-VP': 'SILENT', 'S-SF': 'SILENT', 'S-UI': 'SILENT', 'S-CO': 'SILENT',
    'S-DE': 'SILENT', 'S-CT': 'SILENT', 'S-PV': 'SILENT',
    'D-PI': 'DIALOGUE', 'D-QU': 'DIALOGUE', 'D-DC': 'DIALOGUE', 'D-MC': 'DIALOGUE',
    'SILENT': 'SILENT', 'DIALOGUE': 'DIALOGUE',
}

# Cliffhanger codes, same layout as the hook table above.
_CLIFF_CODE_CATEGORIES = {
    'M-PT': 'MID-ACTION', 'M-CT': 'MID-ACTION', 'M-CH': 'MID-ACTION',
    'M-PU': 'MID-ACTION', 'M-CF': 'MID-ACTION',
    'A-RE': 'AFTERMATH', 'A-CO': 'AFTERMATH', 'A-PS': 'AFTERMATH',
    'A-SI': 'AFTERMATH', 'A-CT': 'AFTERMATH', 'A-DE': 'AFTERMATH',
    'MID-ACTION': 'MID-ACTION', 'AFTERMATH': 'AFTERMATH',
}


def _code_alternation(codes):
    """Return a regex alternation of *codes*, longest code first.

    Longest-first ordering stops a short code from matching as a prefix of a
    longer one.
    """
    return '|'.join(sorted(codes, key=len, reverse=True))


# `### Episode <n>: "<title>"` header line.
_EPISODE_HEADER_RE = re.compile(
    r'###\s*Episode\s*(\d+):\s*"([^"]+)"',
    re.IGNORECASE,
)

# Zero-width split point in front of every episode header. Case-insensitive so
# it agrees with _EPISODE_HEADER_RE; otherwise a lower-case header would stay
# glued to the previous block and that episode would never be seen.
_EPISODE_SPLIT_RE = re.compile(r'(?=###\s*Episode\s*\d+)', re.IGNORECASE)

# `**Sequence:** n | **Beat:** x | **Hook:** code | **Cliffhanger:** code`
_METADATA_RE = re.compile(
    r'\*\*Sequence:\*\*\s*(\d+)\s*\|\s*\*\*Beat:\*\*\s*(\w+[-\w]*)\s*\|'
    r'\s*\*\*Hook:\*\*\s*(' + _code_alternation(_HOOK_CODE_CATEGORIES) + r')\s*\|'
    r'\s*\*\*Cliffhanger:\*\*\s*(' + _code_alternation(_CLIFF_CODE_CATEGORIES) + r')',
    re.IGNORECASE,
)

# `**THE MOMENT:** ...` line.
_MOMENT_RE = re.compile(r'\*\*THE MOMENT:\*\*\s*(.+?)(?:\n|$)', re.IGNORECASE)

# `**[CLIFFHANGER: ...]**` closing image.
_CLIFFHANGER_IMAGE_RE = re.compile(r'\*\*\[CLIFFHANGER:\s*(.+?)\]\*\*', re.IGNORECASE)

# A line containing any of these markers is not prose itself, but switches
# prose counting on for the lines that follow it.
_PROSE_START_MARKERS = (
    '**Sequence:**', '**Threads:**', '**THE MOMENT:**', '**VOICE SEED:**',
)


def _count_prose_words(block):
    """Count whitespace-separated prose words in one episode block.

    Counting starts after the first marker line (see _PROSE_START_MARKERS) and
    stops at the `**[CLIFFHANGER:` line. Header lines, marker lines, blank
    lines and lines starting with `---` are never counted.
    """
    counting = False
    words = 0
    for line in block.split('\n'):
        if _EPISODE_HEADER_RE.match(line):
            continue
        if any(marker in line for marker in _PROSE_START_MARKERS):
            counting = True
            continue
        if '**[CLIFFHANGER:' in line:
            counting = False
            continue
        if counting and line.strip() and not line.startswith('---'):
            words += len(line.split())
    return words


def parse_treatment_batch(filepath, start_ep, end_ep):
    """Parse treatment.md and extract data for the specified episode range.

    Args:
        filepath: Path (str or Path) of the treatment markdown file.
        start_ep: First episode number of the batch (inclusive).
        end_ep: Last episode number of the batch (inclusive).

    Returns:
        A ``(data, error)`` tuple. On success ``error`` is None and ``data``
        is a dict with:
            'batch_episodes': episode dicts whose number is in the range,
            'all_episodes':   episode dicts for every episode in the file,
            'start_ep' / 'end_ep': the requested range.
        If the file is missing or cannot be read, ``data`` is None and
        ``error`` is a message string.

        Each episode dict has 'episode', 'title', 'hook_type',
        'cliffhanger_type', 'has_metadata', 'has_moment',
        'has_cliffhanger_image', 'prose_word_count' and 'issues'. The keys
        'hook_subtype' and 'cliff_subtype' are present only when a metadata
        line was found.
    """
    if not os.path.exists(filepath):
        return None, f"File not found: {filepath}"

    # Report unreadable input (directory, permissions, bad encoding) through
    # the same error channel as a missing file instead of raising.
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as exc:
        return None, f"Could not read {filepath}: {exc}"

    episodes = []
    all_episodes = []  # For cumulative tracking

    for block in _EPISODE_SPLIT_RE.split(content):
        header_match = _EPISODE_HEADER_RE.search(block)
        if not header_match:
            # Preamble, or a header that lacks the `: "Title"` part.
            continue

        ep_num = int(header_match.group(1))

        episode_data = {
            'episode': ep_num,
            'title': header_match.group(2),
            'hook_type': None,
            'cliffhanger_type': None,
            'has_metadata': False,
            'has_moment': False,
            'has_cliffhanger_image': False,
            'prose_word_count': 0,
            'issues': []
        }

        meta_match = _METADATA_RE.search(block)
        if meta_match:
            # Codes are matched case-insensitively; normalise before lookup.
            raw_hook = meta_match.group(3).upper()
            raw_cliff = meta_match.group(4).upper()
            episode_data['hook_type'] = _HOOK_CODE_CATEGORIES.get(raw_hook, raw_hook)
            episode_data['cliffhanger_type'] = _CLIFF_CODE_CATEGORIES.get(raw_cliff, raw_cliff)
            episode_data['hook_subtype'] = raw_hook
            episode_data['cliff_subtype'] = raw_cliff
            episode_data['has_metadata'] = True

        episode_data['has_moment'] = _MOMENT_RE.search(block) is not None
        episode_data['has_cliffhanger_image'] = (
            _CLIFFHANGER_IMAGE_RE.search(block) is not None
        )
        episode_data['prose_word_count'] = _count_prose_words(block)

        all_episodes.append(episode_data)

        # Only include in batch if within range
        if start_ep <= ep_num <= end_ep:
            episodes.append(episode_data)

    return {
        'batch_episodes': episodes,
        'all_episodes': all_episodes,
        'start_ep': start_ep,
        'end_ep': end_ep
    }, None


def _find_overlong_runs(items, max_allowed):
    """Return every run of equal consecutive 'type' values longer than *max_allowed*.

    Args:
        items: List of ``{'ep': <episode number>, 'type': <code>}`` dicts,
            already in the order in which runs should be detected.
        max_allowed: Longest run length that is still acceptable.

    Returns:
        A list of ``{'type', 'start', 'end', 'count'}`` dicts, one per
        offending run, in order of appearance.
    """
    runs = []
    for item in items:
        if runs and runs[-1][-1]['type'] == item['type']:
            runs[-1].append(item)
        else:
            runs.append([item])

    return [
        {
            'type': run[0]['type'],
            'start': run[0]['ep'],
            'end': run[-1]['ep'],
            'count': len(run)
        }
        for run in runs
        if len(run) > max_allowed
    ]


def validate_batch(data):
    """Validate one treatment batch and return a results dict.

    Args:
        data: The dict produced by parse_treatment_batch, with keys
            'batch_episodes', 'all_episodes', 'start_ep' and 'end_ep'.

    Returns:
        A dict with:
            'passed':      True only if every hard gate passed.
            'batch_range': "<start>-<end>" string.
            'hard_gates':  'coverage', 'completeness', 'word_count' and
                           'pattern_variety' sub-dicts, each with a 'passed' flag.
            'warnings':    drift warning strings (these never fail the batch).
            'cumulative':  hook/cliffhanger distribution for all parsed episodes
                           up to end_ep; omitted when there are none.
    """
    batch_eps = data['batch_episodes']
    all_eps = data['all_episodes']
    start_ep = data['start_ep']
    end_ep = data['end_ep']

    results = {
        'passed': True,
        'batch_range': f"{start_ep}-{end_ep}",
        'hard_gates': {},
        'warnings': []
    }
    gates = results['hard_gates']

    # === HARD GATES ===

    # 1. Episode coverage in batch
    expected_eps = range(start_ep, end_ep + 1)
    found_eps = {ep['episode'] for ep in batch_eps}
    missing_eps = [num for num in expected_eps if num not in found_eps]

    gates['coverage'] = {
        'found': len(batch_eps),
        'expected': len(expected_eps),
        'passed': not missing_eps,
        'missing': missing_eps
    }

    # 2. Episode completeness (metadata, THE MOMENT, cliffhanger image)
    required_parts = (
        ('has_metadata', 'missing metadata'),
        ('has_moment', 'missing THE MOMENT'),
        ('has_cliffhanger_image', 'missing cliffhanger image'),
    )
    incomplete = []
    for ep in batch_eps:
        issues = [label for flag, label in required_parts if not ep[flag]]
        if issues:
            incomplete.append({'episode': ep['episode'], 'issues': issues})

    gates['completeness'] = {
        'passed': not incomplete,
        'incomplete': incomplete
    }

    # 3. Batch word count
    batch_words = sum(ep['prose_word_count'] for ep in batch_eps)
    min_words = int(WORDS_PER_BATCH_TARGET - BATCH_WORD_TOLERANCE)
    max_words = int(WORDS_PER_BATCH_TARGET + BATCH_WORD_TOLERANCE)

    gates['word_count'] = {
        'count': batch_words,
        'target': int(WORDS_PER_BATCH_TARGET),
        'range': f"{min_words}-{max_words}",
        'passed': min_words <= batch_words <= max_words
    }

    # 4. Pattern variety within batch — SUBTYPE consecutive check (hard gate)
    # Variety = subtype variety. A run of M-CF, M-PT, M-CT = fine (different kinds).
    # A run of the same subtype longer than MAX_CONSECUTIVE_SAME_TYPE = violation.
    sorted_batch = sorted(batch_eps, key=lambda ep: ep['episode'])

    # Episodes without any hook / cliffhanger code are left out of the
    # sequences; they are already reported by the completeness gate.
    hook_sub_seq = [
        {'ep': ep['episode'], 'type': ep.get('hook_subtype', ep['hook_type'])}
        for ep in sorted_batch
        if ep.get('hook_subtype') or ep.get('hook_type')
    ]
    cliff_sub_seq = [
        {'ep': ep['episode'], 'type': ep.get('cliff_subtype', ep['cliffhanger_type'])}
        for ep in sorted_batch
        if ep.get('cliff_subtype') or ep.get('cliffhanger_type')
    ]

    # Use the configured limit so the gate enforces the same number that the
    # report prints.
    hook_violations = _find_overlong_runs(hook_sub_seq, MAX_CONSECUTIVE_SAME_TYPE)
    cliff_violations = _find_overlong_runs(cliff_sub_seq, MAX_CONSECUTIVE_SAME_TYPE)

    gates['pattern_variety'] = {
        'passed': not hook_violations and not cliff_violations,
        'hook_violations': hook_violations,
        'cliffhanger_violations': cliff_violations
    }

    # === CUMULATIVE TRACKING (warnings, not hard gates) ===

    # Calculate cumulative distribution up to and including this batch
    cumulative_eps = [ep for ep in all_eps if ep['episode'] <= end_ep]

    if cumulative_eps:
        # Cumulative hook distribution (episodes without a hook type are ignored)
        hooks = [ep['hook_type'] for ep in cumulative_eps if ep['hook_type']]
        silent_count = hooks.count('SILENT')
        total_hooks = len(hooks)
        silent_pct = (silent_count / total_hooks * 100) if total_hooks > 0 else 0

        # Cumulative cliffhanger distribution
        cliffs = [ep['cliffhanger_type'] for ep in cumulative_eps if ep['cliffhanger_type']]
        midaction_count = cliffs.count('MID-ACTION')
        total_cliffs = len(cliffs)
        midaction_pct = (midaction_count / total_cliffs * 100) if total_cliffs > 0 else 0

        silent_lo, silent_hi = SILENT_HOOK_TARGET
        mid_lo, mid_hi = MIDACTION_CLIFF_TARGET
        hooks_on_target = silent_lo <= silent_pct <= silent_hi
        cliffs_on_target = mid_lo <= midaction_pct <= mid_hi

        results['cumulative'] = {
            'episodes_so_far': len(cumulative_eps),
            'total_words_so_far': sum(ep['prose_word_count'] for ep in cumulative_eps),
            'hooks': {
                'silent': silent_count,
                'dialogue': total_hooks - silent_count,
                'silent_pct': round(silent_pct, 1),
                'on_target': hooks_on_target
            },
            'cliffhangers': {
                'midaction': midaction_count,
                'aftermath': total_cliffs - midaction_count,
                'midaction_pct': round(midaction_pct, 1),
                'on_target': cliffs_on_target
            }
        }

        # Generate drift warnings
        if not hooks_on_target:
            direction = 'too few' if silent_pct < silent_lo else 'too many'
            results['warnings'].append(
                f"SILENT hooks at {silent_pct:.1f}% "
                f"(target: {silent_lo}-{silent_hi}%) - {direction} SILENT"
            )

        if not cliffs_on_target:
            direction = 'too few' if midaction_pct < mid_lo else 'too many'
            results['warnings'].append(
                f"MID-ACTION cliffs at {midaction_pct:.1f}% "
                f"(target: {mid_lo}-{mid_hi}%) - {direction} MID-ACTION"
            )

    # Overall pass requires every hard gate; warnings never affect it.
    results['passed'] = all([
        gates['coverage']['passed'],
        gates['completeness']['passed'],
        gates['word_count']['passed'],
        gates['pattern_variety']['passed']
    ])

    return results


def print_report(results, project_name, start_ep, end_ep):
    """Print the formatted batch validation report and return the verdict.

    Args:
        results: The dict returned by validate_batch.
        project_name: Name shown in the title and in the follow-up commands.
        start_ep: First episode of the batch; also used to derive the batch number.
        end_ep: Last episode of the batch; also used to suggest the next batch.

    Returns:
        ``results['passed']`` unchanged.
    """
    heavy_rule = '═' * 65
    light_rule = '─' * 65
    gates = results['hard_gates']

    def gate_mark(ok):
        # Hard gates show a green check or a red cross.
        return f"{GREEN}✓{RESET}" if ok else f"{RED}✗{RESET}"

    def drift_mark(ok):
        # Cumulative ratios show a green check or a yellow warning sign.
        return f"{GREEN}✓{RESET}" if ok else f"{YELLOW}⚠{RESET}"

    # Title banner
    batch_index = (start_ep - 1) // BATCH_SIZE + 1
    print(f"\n{heavy_rule}")
    print(f"{BOLD}TREATMENT BATCH VALIDATION: {project_name.upper()}{RESET}")
    print(f"Batch {batch_index}: Episodes {start_ep}-{end_ep}")
    print(f"{heavy_rule}\n")

    print(f"{BOLD}HARD GATES{RESET}")
    print(f"{light_rule}\n")

    # Gate 1: coverage
    coverage = gates['coverage']
    coverage_mark = gate_mark(coverage['passed'])
    print("EPISODE COVERAGE")
    print(f"  Episodes found: {coverage['found']}/{coverage['expected']} {coverage_mark}")
    if coverage['missing']:
        missing_list = ', '.join(str(num) for num in coverage['missing'])
        print(f"  {RED}Missing: Ep {missing_list}{RESET}")

    # Gate 2: completeness
    completeness = gates['completeness']
    completeness_mark = gate_mark(completeness['passed'])
    print("\nEPISODE COMPLETENESS")
    print(f"  All complete: {completeness_mark}")
    for entry in completeness['incomplete']:
        print(f"  {RED}Ep {entry['episode']}: {', '.join(entry['issues'])}{RESET}")

    # Gate 3: word count
    word_gate = gates['word_count']
    word_mark = gate_mark(word_gate['passed'])
    print("\nBATCH WORD COUNT")
    print(f"  Words: {word_gate['count']} (target: {word_gate['target']}, range: {word_gate['range']}) {word_mark}")

    # Gate 4: pattern variety
    variety = gates['pattern_variety']
    variety_mark = gate_mark(variety['passed'])
    print(f"\nPATTERN VARIETY (max {MAX_CONSECUTIVE_SAME_TYPE} consecutive same subtype)")
    print(f"  All subtypes varied: {variety_mark}")
    for label, key in (('Hook', 'hook_violations'), ('Cliff', 'cliffhanger_violations')):
        for run in variety[key]:
            print(f"  {RED}{label}: {run['count']}x {run['type']} (Ep {run['start']}-{run['end']}){RESET}")

    # Cumulative section is only present when validate_batch produced one.
    if 'cumulative' in results:
        print(f"\n{BOLD}CUMULATIVE TRACKING (through Ep {end_ep}){RESET}")
        print(f"{light_rule}\n")

        totals = results['cumulative']
        print(f"Total episodes: {totals['episodes_so_far']}")
        print(f"Total words: {totals['total_words_so_far']}")

        hook_stats = totals['hooks']
        hook_mark = drift_mark(hook_stats['on_target'])
        print(f"\nHOOKS: {hook_stats['silent']} SILENT ({hook_stats['silent_pct']}%) / {hook_stats['dialogue']} DIALOGUE {hook_mark}")

        cliff_stats = totals['cliffhangers']
        cliff_mark = drift_mark(cliff_stats['on_target'])
        print(f"CLIFFS: {cliff_stats['midaction']} MID-ACTION ({cliff_stats['midaction_pct']}%) / {cliff_stats['aftermath']} AFTERMATH {cliff_mark}")

    # Drift warnings
    if results['warnings']:
        print(f"\n{BOLD}DRIFT WARNINGS{RESET}")
        print(f"{light_rule}\n")
        for message in results['warnings']:
            print(f"  {YELLOW}⚠ {message}{RESET}")

    # Verdict and next steps
    print(f"\n{heavy_rule}")
    if not results['passed']:
        print(f"{BOLD}{RED}RESULT: FAIL{RESET}")
        print(f"\n{YELLOW}FIX REQUIRED before proceeding:{RESET}")
        if not gates['coverage']['passed']:
            print("  - Add missing episodes")
        if not gates['completeness']['passed']:
            print("  - Complete episode metadata, THE MOMENT, and cliffhanger images")
        if not gates['word_count']['passed']:
            counted = gates['word_count']['count']
            lower_bound = int(WORDS_PER_BATCH_TARGET - BATCH_WORD_TOLERANCE)
            if counted < lower_bound:
                print(f"  - Batch too short ({counted} words) - expand prose")
            else:
                print(f"  - Batch too long ({counted} words) - tighten prose")
        if not gates['pattern_variety']['passed']:
            print("  - Break up consecutive same-subtype patterns (e.g., 4x M-PT → vary with M-CF, M-CT)")
    else:
        if results['warnings']:
            headline_color, suffix = YELLOW, " (with warnings)"
        else:
            headline_color, suffix = GREEN, ""
        print(f"{BOLD}{headline_color}RESULT: PASS{suffix}{RESET}")

        upcoming_start = end_ep + 1
        if upcoming_start > TOTAL_EPISODES:
            print(f"\n{GREEN}ALL BATCHES COMPLETE!{RESET}")
            print(f"  Run full validation: python3 /tools/validate_treatment.py ./{project_name}")
        else:
            upcoming_end = min(upcoming_start + BATCH_SIZE - 1, TOTAL_EPISODES)
            print(f"\n{CYAN}PROCEED TO NEXT BATCH:{RESET}")
            print(f"  /treatment {project_name} --batch {upcoming_start}-{upcoming_end}")
    print(f"{heavy_rule}\n")

    return results['passed']


def main():
    """Command-line entry point.

    Expects ``<project_path> <start_ep> <end_ep>`` on the command line, parses
    ``<project_path>/treatment.md``, validates the requested batch and prints
    the report.

    Exits with status 0 when the batch passes. Exits with status 1 when the
    batch fails, the arguments are missing or invalid, or the treatment file
    cannot be parsed.
    """
    if len(sys.argv) < 4:
        print(f"Usage: python {sys.argv[0]} <project_path> <start_ep> <end_ep>")
        print(f"Example: python {sys.argv[0]} ./leviathan 1 10")
        print(f"         python {sys.argv[0]} ./leviathan 11 20")
        sys.exit(1)

    project_path = sys.argv[1]
    try:
        start_ep = int(sys.argv[2])
        end_ep = int(sys.argv[3])
    except ValueError:
        print(f"{RED}Error: start_ep and end_ep must be integers{RESET}")
        sys.exit(1)

    if start_ep < 1 or end_ep > TOTAL_EPISODES or start_ep > end_ep:
        print(f"{RED}Error: Invalid episode range. Must be 1-{TOTAL_EPISODES}{RESET}")
        sys.exit(1)

    # Normalise before taking the last component: Path('.').name is '' and
    # Path('..').name is '..', which would leave the report title and the
    # suggested follow-up commands without a usable project name.
    project_name = Path(os.path.abspath(project_path)).name
    treatment_path = Path(project_path) / 'treatment.md'

    data, err = parse_treatment_batch(treatment_path, start_ep, end_ep)
    if err:
        print(f"{RED}Error: {err}{RESET}")
        sys.exit(1)

    results = validate_batch(data)
    passed = print_report(results, project_name, start_ep, end_ep)
    sys.exit(0 if passed else 1)


if __name__ == '__main__':
    main()
