#!/usr/bin/env python3
"""
Batch Analysis Tool
Analyzes a batch of generated episodes for quality metrics.

Usage:
    python analyze_batch.py <project_path> <batch_number>
    python analyze_batch.py <project_path> --all
    python analyze_batch.py ./leviathan 3

Checks:
- Word count (450-500 target, see CONSTANTS.md)
- Dialogue percentage (≤40%)
- Exchange count (≤8)
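- Cliffhanger variety (flags 4+ consecutive episodes with the same type)
- Hook variety (flags 4+ consecutive episodes with the same type)
- Emotional beat presence (keyword check on scheduled episodes)
- Thread opportunities (not yet implemented by this script)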
"""

import sys
import re
from pathlib import Path
from collections import defaultdict

# Import constants and shared counting functions from shared module (reads from CONSTANTS.md)
try:
    from engine_constants import (
        WORD_COUNT_MIN, WORD_COUNT_MAX,
        DIALOGUE_MAX_PERCENT, MAX_EXCHANGES,
        GENERATION_BATCH_SIZE,
        count_words as _shared_count_words,
        parse_dialogue_blocks,
        count_dialogue_words as _shared_count_dialogue_words,
        count_exchanges as _shared_count_exchanges,
        calculate_dialogue_percent,
    )
    _USE_SHARED = True
except ImportError:
    # Fallback if engine_constants not available
    print("WARNING: Could not import engine_constants, using fallback values")
    WORD_COUNT_MIN = 450
    WORD_COUNT_MAX = 500
    DIALOGUE_MAX_PERCENT = 40
    MAX_EXCHANGES = 8
    GENERATION_BATCH_SIZE = 5
    _USE_SHARED = False

# ANSI colors
# Terminal escape sequences used by the report printer. RESET restores the
# default style after each colored span. They are emitted unconditionally
# (there is no TTY check), so redirected output will contain raw escape codes.
GREEN = '\033[92m'   # in-range values / passing status
RED = '\033[91m'     # out-of-range values, failures, violations
YELLOW = '\033[93m'  # near-miss word counts and per-episode issue lines
CYAN = '\033[96m'    # defined but not used elsewhere in this file
BOLD = '\033[1m'     # section headings
RESET = '\033[0m'    # return to default terminal style


def get_batch_episodes(batch_num):
    """Return the 1-based episode numbers that make up batch ``batch_num``.

    Batch 1 holds episodes 1..GENERATION_BATCH_SIZE, batch 2 the next
    GENERATION_BATCH_SIZE episodes, and so on.
    """
    first_episode = (batch_num - 1) * GENERATION_BATCH_SIZE + 1
    return list(range(first_episode, first_episode + GENERATION_BATCH_SIZE))


def read_episode(project_path, ep_num):
    """Return the text of ``<project_path>/episodes/ep_NNN.md``, or None if absent.

    The episode number is zero-padded to three digits; the file is decoded
    as UTF-8.
    """
    episode_file = project_path / 'episodes' / f'ep_{ep_num:03d}.md'
    return episode_file.read_text(encoding='utf-8') if episode_file.exists() else None


def count_words(content):
    """Return the total word count of an episode.

    Uses the canonical counter from engine_constants when it was importable;
    otherwise counts every whitespace-separated token in ``content``
    (front matter included).
    """
    if not _USE_SHARED:
        return len(content.split())
    return _shared_count_words(content)


def count_dialogue(content):
    """Count dialogue words and the percentage of the episode they represent.

    Delegates to the shared canonical implementation from engine_constants
    when it could be imported; otherwise uses a local regex heuristic.

    Args:
        content: Raw episode text.

    Returns:
        ``(dialogue_words, percentage)``. In the fallback, ``percentage`` is
        ``dialogue_words / total_words * 100`` (0 when there are no words);
        on the shared path it is whatever ``calculate_dialogue_percent``
        returns.
    """
    if _USE_SHARED:
        total_words = _shared_count_words(content)
        blocks = parse_dialogue_blocks(content)
        dialogue_words = _shared_count_dialogue_words(blocks)
        percentage = calculate_dialogue_percent(total_words, dialogue_words)
        return dialogue_words, percentage

    # Fallback: drop a leading "---...---" front-matter block, then take the
    # text that follows each line starting with a capital and containing only
    # capitals/whitespace (a speaker cue), up to a blank line, a new line that
    # starts with a capital, or the end of the text.
    content_stripped = re.sub(r'^---.*?---', '', content, flags=re.DOTALL)

    dialogue_pattern = re.compile(r'^[A-Z][A-Z\s]+\n(.+?)(?=\n\n|\n[A-Z]|\Z)', re.MULTILINE | re.DOTALL)
    dialogue_matches = dialogue_pattern.findall(content_stripped)

    # Count dialogue words with the same whitespace split used for the total
    # below. A `\b\w+\b` count would split "don't" / "half-empty" into two
    # words while the total counts them once, inflating the percentage (it
    # could even exceed 100%).
    dialogue_words = sum(len(match.split()) for match in dialogue_matches)

    # Total covers the whole file, front matter included, the same way the
    # count_words fallback does.
    total_words = len(content.split())
    percentage = (dialogue_words / total_words * 100) if total_words > 0 else 0

    return dialogue_words, percentage


def count_exchanges(content):
    """Count dialogue exchanges (speaker cues) in an episode.

    Delegates to the shared canonical implementation from engine_constants
    when it could be imported. The fallback strips a leading ``---`` front
    matter block and counts lines made only of capital letters, spaces and
    tabs (starting with a capital, at least two characters long).
    """
    if _USE_SHARED:
        blocks = parse_dialogue_blocks(content)
        return _shared_count_exchanges(blocks)

    content_stripped = re.sub(r'^---.*?---', '', content, flags=re.DOTALL)
    # `[A-Z \t]` rather than `[A-Z\s]`: `\s` also matches newlines, so two
    # adjacent all-caps cue lines (or two cues separated only by blank lines)
    # were matched as ONE span and under-counted.
    return len(re.findall(r'^[A-Z][A-Z \t]+$', content_stripped, re.MULTILINE))


def detect_cliffhanger_type(content):
    """Heuristically classify an episode's closing cliffhanger.

    Only the last blank-line-separated paragraph is inspected. Categories are
    tried in priority order and the first hit wins:

    * ``'M'`` - mid-action phrase (e.g. "reaches for", "lunges")
    * ``'C'`` - consequence phrase (e.g. "explodes", "blood")
    * ``'R'`` - revelation phrase (e.g. "realizes", "the truth")
    * ``'A'`` - the paragraph opens with an all-caps line (a speaker cue,
      i.e. the episode ends on dialogue)
    * ``'UNKNOWN'`` - nothing matched (including empty content)

    Indicators must begin at a word boundary, so "taxi amid" no longer
    counts as "i am" and "bonfires" no longer counts as "fires"; suffixes
    are still allowed ("bloody" still matches "blood").
    """
    paragraphs = content.strip().split('\n\n')
    final_paragraph = paragraphs[-1]
    last = final_paragraph.lower()

    indicator_groups = (
        ('M', ('finger tightens', 'reaches for', 'lunges', 'fires', 'pulls the', 'leaps', 'dives', 'swings')),
        ('C', ('falls', 'crashes', 'explodes', 'collapses', 'screams', 'blood')),
        ('R', ('realizes', 'sees', 'discovers', 'the truth', 'it was', 'you are', 'i am')),
    )
    for cliff_type, indicators in indicator_groups:
        if any(re.search(r'\b' + re.escape(indicator), last) for indicator in indicators):
            return cliff_type

    # Aftermath (dialogue ending): the final paragraph starts with a speaker cue.
    first_line = final_paragraph.strip().split('\n')[0]
    if re.search(r'^[A-Z][A-Z\s]+$', first_line):
        return 'A'

    return 'UNKNOWN'


def detect_hook_type(content):
    """Heuristically classify an episode's opening hook.

    A leading ``---...---`` front-matter block is removed, then only the
    first blank-line-separated paragraph is inspected:

    * ``'DIALOGUE'`` - its first line is all caps (a speaker cue)
    * ``'ACTION'``   - it contains an action indicator word
    * ``'SILENT'``   - anything else
    * ``'UNKNOWN'``  - there is no body text after the front matter

    Indicators must begin at a word boundary, so "bonfires" does not count
    as "fires"; suffixes are still allowed ("alarms" matches "alarm").
    """
    body = re.sub(r'^---.*?---', '', content, flags=re.DOTALL).strip()
    # str.split never returns an empty list, so test the text itself:
    # a front-matter-only episode has no hook to classify.
    if not body:
        return 'UNKNOWN'

    first_paragraph = body.split('\n\n')[0]

    # Dialogue hook
    first_line = first_paragraph.strip().split('\n')[0]
    if re.match(r'^[A-Z][A-Z\s]+$', first_line):
        return 'DIALOGUE'

    # Action indicators
    lowered = first_paragraph.lower()
    action = ('runs', 'fires', 'ducks', 'slams', 'crashes', 'explosion', 'alarm')
    if any(re.search(r'\b' + re.escape(indicator), lowered) for indicator in action):
        return 'ACTION'

    return 'SILENT'


def check_emotional_beat(content, ep_num, emotional_episodes=None):
    """Check whether an episode is scheduled for, and shows, an emotional beat.

    Args:
        content: Raw episode text.
        ep_num: Episode number.
        emotional_episodes: Optional collection of episode numbers scheduled
            for an emotional beat. Defaults to the built-in schedule
            (10, 15, 20, 26, 30, 32, 33, 36, 42, 45, 50, 59, 60).

    Returns:
        ``(should_have, has_indicator)``: whether ``ep_num`` is scheduled,
        and whether any indicator phrase (or the 💔 marker) appears anywhere
        in ``content``, compared case-insensitively.
    """
    if emotional_episodes is None:
        emotional_episodes = (10, 15, 20, 26, 30, 32, 33, 36, 42, 45, 50, 59, 60)

    should_have = ep_num in emotional_episodes

    # Indicators are stored lowercase; lowercase the content once, not per indicator.
    lowered = content.lower()
    emotional_indicators = ('💔', 'emotional beat', 'first time', 'finally', 'real name', 'i love', 'forgive')
    has_indicator = any(indicator in lowered for indicator in emotional_indicators)

    return should_have, has_indicator


def analyze_episode(project_path, ep_num):
    """Collect the raw quality metrics for a single episode.

    Returns None when the episode file is missing or empty. Otherwise returns
    a metrics dict whose 'issues' list starts empty; threshold checks are
    left to the caller.
    """
    text = read_episode(project_path, ep_num)
    if not text:
        return None

    total_words = count_words(text)
    dlg_words, dlg_percent = count_dialogue(text)
    exchange_count = count_exchanges(text)
    cliff_kind = detect_cliffhanger_type(text)
    hook_kind = detect_hook_type(text)
    beat_expected, beat_found = check_emotional_beat(text, ep_num)

    metrics = {
        'episode': ep_num,
        'word_count': total_words,
        'dialogue_words': dlg_words,
        'dialogue_pct': round(dlg_percent, 1),
        'exchanges': exchange_count,
        'cliffhanger_type': cliff_kind,
        'hook_type': hook_kind,
        'should_have_emotional_beat': beat_expected,
        'has_emotional_indicator': beat_found,
        'issues': [],
    }
    return metrics


def check_patterns(episodes, max_consecutive=3):
    """Report runs of identical cliffhanger/hook types that are too long.

    Per CONSTANTS.md at most 3 consecutive episodes may share a type, so a
    run of 4+ is a violation. Each offending run is reported exactly once,
    spanning its first through last episode (a run of 5 yields one
    "Ep 1-5" message rather than overlapping "Ep 1-4" and "Ep 2-5" ones).
    ``'UNKNOWN'`` types never count as a violation.

    Args:
        episodes: Analysis dicts in episode order; each needs 'episode',
            'cliffhanger_type' and 'hook_type' keys.
        max_consecutive: Longest allowed run of one type (default 3).

    Returns:
        List of human-readable violation strings, cliffhanger runs first,
        then hook runs.
    """
    return (
        _find_long_runs(episodes, 'cliffhanger_type', 'cliffhangers', max_consecutive)
        + _find_long_runs(episodes, 'hook_type', 'hooks', max_consecutive)
    )


def _find_long_runs(episodes, key, label, max_consecutive):
    """Return one message per maximal run of equal ``ep[key]`` values longer than ``max_consecutive``."""
    violations = []
    start = 0
    while start < len(episodes):
        value = episodes[start][key]
        # Extend `end` to the last index of the current run of equal values.
        end = start
        while end + 1 < len(episodes) and episodes[end + 1][key] == value:
            end += 1
        if value != 'UNKNOWN' and end - start + 1 > max_consecutive:
            violations.append(
                f"{max_consecutive + 1}+ consecutive {value} {label} "
                f"(Ep {episodes[start]['episode']}-{episodes[end]['episode']})"
            )
        start = end + 1
    return violations


def analyze_batch(project_path, batch_num):
    """Analyze every existing episode in a batch and record threshold issues.

    Missing or empty episodes are skipped. Each returned analysis dict has
    its 'issues' list populated with word-count, dialogue-percentage,
    exchange-count and scheduled-emotional-beat problems, judged against
    the CONSTANTS.md thresholds.
    """
    analyses = []
    for episode_number in get_batch_episodes(batch_num):
        result = analyze_episode(project_path, episode_number)
        if not result:
            continue

        issues = result['issues']
        words = result['word_count']
        if words < WORD_COUNT_MIN:
            issues.append(f"Under word count ({words}/{WORD_COUNT_MIN})")
        elif words > WORD_COUNT_MAX:
            issues.append(f"Over word count ({words}/{WORD_COUNT_MAX})")

        dialogue_share = result['dialogue_pct']
        if dialogue_share > DIALOGUE_MAX_PERCENT:
            issues.append(f"Dialogue over {DIALOGUE_MAX_PERCENT}% ({dialogue_share}%)")

        exchange_total = result['exchanges']
        if exchange_total > MAX_EXCHANGES:
            issues.append(f"Too many exchanges ({exchange_total}/{MAX_EXCHANGES})")

        beat_missing = result['should_have_emotional_beat'] and not result['has_emotional_indicator']
        if beat_missing:
            issues.append("Missing emotional beat (scheduled episode)")

        analyses.append(result)

    return analyses


def print_batch_report(results, batch_num, project_name):
    """Print a color-coded quality report for one batch of episodes.

    Sections: summary averages, pattern violations (via check_patterns), a
    per-episode table, cliffhanger/hook distributions, and a list of every
    issue prefixed with its episode number.

    Args:
        results: Analysis dicts as produced by analyze_batch; may be empty.
        batch_num: Batch number shown in the header.
        project_name: Project name shown upper-cased in the header.

    Returns:
        True if no episode has issues and there are no pattern violations;
        otherwise False (also False when ``results`` is empty).
    """
    print(f"\n{'═' * 70}")
    print(f"{BOLD}BATCH {batch_num} ANALYSIS: {project_name.upper()}{RESET}")
    print(f"{'═' * 70}\n")

    if not results:
        print(f"{RED}No episodes found for batch {batch_num}{RESET}")
        # Always return a bool so callers never receive an implicit None.
        return False

    # Summary stats (results is non-empty here, so the averages are defined)
    word_counts = [r['word_count'] for r in results]
    dialogue_pcts = [r['dialogue_pct'] for r in results]
    exchange_counts = [r['exchanges'] for r in results]

    avg_words = sum(word_counts) / len(word_counts)
    avg_dialogue = sum(dialogue_pcts) / len(dialogue_pcts)
    avg_exchanges = sum(exchange_counts) / len(exchange_counts)

    print(f"{BOLD}SUMMARY:{RESET}")
    print(f"  Episodes analyzed: {len(results)}")
    print(f"  Avg word count:    {avg_words:.0f} (target: {WORD_COUNT_MIN}-{WORD_COUNT_MAX})")
    print(f"  Avg dialogue:      {avg_dialogue:.1f}% (max: {DIALOGUE_MAX_PERCENT}%)")
    print(f"  Avg exchanges:     {avg_exchanges:.1f} (max: {MAX_EXCHANGES})")

    # Pattern check
    pattern_violations = check_patterns(results)
    if pattern_violations:
        print(f"\n{RED}PATTERN VIOLATIONS:{RESET}")
        for v in pattern_violations:
            print(f"  ✗ {v}")
    else:
        print(f"\n{GREEN}✓ No pattern violations{RESET}")

    # Per-episode details
    print(f"\n{BOLD}EPISODE DETAILS:{RESET}\n")

    header = f"{'EP':>3} | {'WORDS':>5} | {'DLG%':>5} | {'EXCH':>4} | {'CLIFF':>5} | {'HOOK':>8} | {'STATUS'}"
    print(header)
    print('─' * 70)

    for r in results:
        ep = r['episode']
        words = r['word_count']
        dlg = r['dialogue_pct']
        exch = r['exchanges']
        cliff = r['cliffhanger_type']
        hook = r['hook_type']

        # Color coding (using constants, with 30-word buffer for yellow)
        word_color = GREEN if WORD_COUNT_MIN <= words <= WORD_COUNT_MAX else (YELLOW if WORD_COUNT_MIN - 30 <= words <= WORD_COUNT_MAX + 30 else RED)
        dlg_color = GREEN if dlg <= DIALOGUE_MAX_PERCENT else RED
        exch_color = GREEN if exch <= MAX_EXCHANGES else RED

        status = f"{GREEN}✓{RESET}" if not r['issues'] else f"{RED}✗{RESET}"

        print(f"{ep:3} | {word_color}{words:5}{RESET} | {dlg_color}{dlg:5.1f}{RESET} | {exch_color}{exch:4}{RESET} | {cliff:5} | {hook:8} | {status}")

        for issue in r['issues']:
            print(f"      {YELLOW}└─ {issue}{RESET}")

    # Cliffhanger / hook distributions
    cliff_counts = defaultdict(int)
    hook_counts = defaultdict(int)
    for r in results:
        cliff_counts[r['cliffhanger_type']] += 1
        hook_counts[r['hook_type']] += 1

    print(f"\n{BOLD}DISTRIBUTIONS:{RESET}")
    print(f"  Cliffhangers: {dict(cliff_counts)}")
    print(f"  Hooks: {dict(hook_counts)}")

    # Issues summary: prefix each issue with its episode so the list is
    # actionable without cross-referencing the table above.
    all_issues = [f"Ep {r['episode']}: {issue}" for r in results for issue in r['issues']]
    if all_issues:
        print(f"\n{BOLD}{RED}ISSUES TO ADDRESS ({len(all_issues)}):{RESET}")
        for issue in all_issues:
            print(f"  - {issue}")
    else:
        print(f"\n{BOLD}{GREEN}✓ ALL EPISODES PASS{RESET}")

    print(f"\n{'═' * 70}\n")

    return not all_issues and not pattern_violations


def _count_batches_on_disk(project_path):
    """Return how many batches cover the highest-numbered episodes/ep_<N>.md file (0 if none)."""
    highest = 0
    for path in (project_path / 'episodes').glob('ep_*.md'):
        match = re.fullmatch(r'ep_(\d+)\.md', path.name)
        if match:
            highest = max(highest, int(match.group(1)))
    # Ceiling division: a partially filled final batch still needs analysing.
    return (highest + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE


def main():
    """Command-line entry point.

    Usage: ``<project_path> <batch_number|--all>``. ``--all`` analyzes every
    batch up to the highest-numbered episode file found on disk, so the
    batch count follows GENERATION_BATCH_SIZE and the project's actual
    length instead of assuming 60 episodes in batches of 5.

    Exits with status 0 when every analyzed batch passes, 1 otherwise
    (including bad arguments or no episodes found).
    """
    if len(sys.argv) < 3:
        print(f"Usage: python {sys.argv[0]} <project_path> <batch_number|--all>")
        print(f"Example: python {sys.argv[0]} ./leviathan 3")
        sys.exit(1)

    project_path = Path(sys.argv[1])
    batch_arg = sys.argv[2]

    if batch_arg == '--all':
        batch_count = _count_batches_on_disk(project_path)
        if batch_count == 0:
            # Nothing analyzed must not be reported as a pass.
            print(f"No episode files found under {project_path / 'episodes'}", file=sys.stderr)
            sys.exit(1)

        all_passed = True
        for batch_num in range(1, batch_count + 1):
            results = analyze_batch(project_path, batch_num)
            if results:
                passed = print_batch_report(results, batch_num, project_path.name)
                if not passed:
                    all_passed = False

        sys.exit(0 if all_passed else 1)

    try:
        batch_num = int(batch_arg)
    except ValueError:
        batch_num = None
    if batch_num is None or batch_num < 1:
        print(f"Invalid batch number: {batch_arg!r} (expected a positive integer or --all)", file=sys.stderr)
        sys.exit(1)

    results = analyze_batch(project_path, batch_num)
    passed = print_batch_report(results, batch_num, project_path.name)
    sys.exit(0 if passed else 1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
