#!/usr/bin/python3
"""
Voice Contamination Detection

Detects unintentional voice issues at key batches (3, 6, 9, 12):
1. Contamination - Character A using Character B's idioms/patterns
2. Regression - Character unexpectedly reverting to earlier voice stage
3. Generic drift - Characters becoming interchangeable

This does NOT flag intentional voice evolution (mask cracking, relationship
deepening). Character voice SHOULD change over the series - this catches
UNINTENTIONAL drift.

Method:
1. Load character signature patterns from characters.md
2. Extract dialogue from current batch and previous batch
3. For each character, check:
   - Are they using another character's signature patterns? (contamination)
   - Did their voice metrics change suddenly from last batch? (regression)
   - Are characters becoming more similar to each other? (generic drift)

Usage: python3 baseline_comparison.py <project_path> <batch_number>
Example: python3 baseline_comparison.py ./leviathan 6

Triggered at: Batches 3, 6, 9, 12 (by save_checkpoint.py)

Exit codes:
- 0 = No issues detected (also returned when the batch is not a check batch or no dialogue is found)
- 1 = Contamination, regression, or generic drift detected (soft gate - recommends review)
- 2 = Error (missing files, etc.)
"""

import json
import sys
import re
from pathlib import Path
from collections import defaultdict

# Add engine tools to path for imports.
# The 'tools' directory is looked up two levels above this script's directory
# and is only prepended to sys.path when it actually exists.
_SCRIPT_DIR = Path(__file__).parent.resolve()
_ENGINE_TOOLS = _SCRIPT_DIR.parent.parent / 'tools'
if _ENGINE_TOOLS.exists():
    sys.path.insert(0, str(_ENGINE_TOOLS))

# Prefer the shared cue detector and batch size from engine_constants. When
# that module cannot be imported, _USE_SHARED is False (the dialogue extractor
# then uses its own fallback regex) and the batch size defaults to 5.
try:
    from engine_constants import (
        is_character_cue as _shared_is_character_cue,
        GENERATION_BATCH_SIZE,
    )
    _USE_SHARED = True
except ImportError:
    _USE_SHARED = False
    GENERATION_BATCH_SIZE = 5


def extract_dialogue_from_episode(episode_path):
    """
    Extract dialogue from a single episode file.

    The file is scanned line by line for Fountain-style character cues. Every
    non-blank line after a cue is attributed to that character until a blank
    line closes the dialogue block; parenthetical directions such as
    "(whispering)" are skipped.

    Cue detection uses the shared engine helper when it was importable, and a
    local all-caps regex otherwise. The fallback accepts an optional leading
    '@' (forced cue) and an optional trailing parenthetical extension such as
    "(V.O.)" or "(CONT'D)".

    Args:
        episode_path: pathlib.Path of the episode file. A path that does not
            exist yields an empty result rather than an error.

    Returns:
        defaultdict(list) of {character_name: [list of lines]}. Character
        names have any leading '@' and parenthetical extension removed.
    """
    dialogue = defaultdict(list)

    if not episode_path.exists():
        return dialogue

    # Decode explicitly as UTF-8 so dashes and curly quotes in the script are
    # read the same way regardless of the host's locale encoding.
    content = episode_path.read_text(encoding='utf-8')

    current_character = None
    in_dialogue = False

    for line in content.split('\n'):
        stripped = line.strip()

        if not stripped:
            # Blank line ends dialogue block
            in_dialogue = False
            continue

        # Handle @ prefix (Fountain forced character cue)
        candidate = stripped.lstrip('@').strip()

        if _USE_SHARED:
            is_cue = _shared_is_character_cue(candidate)
        else:
            # Drop a trailing extension before testing, otherwise cues like
            # "JOHN (V.O.)" never match and their dialogue is lost.
            bare = re.sub(r'\s*\([^)]+\)\s*$', '', stripped)
            is_cue = bool(bare and re.match(r'^@?[A-Z][A-Z\s\']+$', bare))

        if is_cue:
            # Remove parenthetical extensions for character name
            current_character = re.sub(r'\s*\([^)]+\)\s*', '', candidate).strip()
            in_dialogue = True
        elif current_character and in_dialogue:
            # Skip parentheticals
            if stripped.startswith('(') and stripped.endswith(')'):
                continue
            dialogue[current_character].append(stripped)

    return dialogue


def extract_batch_dialogue(project_path, batch_num):
    """
    Collect the dialogue of every episode belonging to one batch.

    A batch covers GENERATION_BATCH_SIZE consecutive episodes, read from
    <project_path>/episodes/ep_NNN.md. Episode files that do not exist are
    skipped.

    Returns dict of {character_name: [list of lines]}
    """
    merged = defaultdict(list)

    first_episode = (batch_num - 1) * GENERATION_BATCH_SIZE + 1
    episode_numbers = range(first_episode, first_episode + GENERATION_BATCH_SIZE)

    for number in episode_numbers:
        episode_file = project_path / "episodes" / f"ep_{number:03d}.md"
        if not episode_file.exists():
            continue

        per_episode = extract_dialogue_from_episode(episode_file)
        for speaker in per_episode:
            merged[speaker].extend(per_episode[speaker])

    return merged


def load_character_patterns(project_path):
    """
    Load character signature patterns from bible/characters.md.

    The file is split on level-2 markdown headers ("## "). A section is
    treated as a character when its header line starts with a capitalised
    single-word name followed by an em dash ("## Alice — ..."); other
    sections are ignored.

    Args:
        project_path: pathlib.Path of the project root.

    Returns:
        dict of {CHARACTER_NAME (upper-cased): {
            'signature_phrases': up to 5 lower-cased block-quoted samples,
            'forbidden_phrases': up to 5 lower-cased Anti-Patterns bullets,
            'idiom_keywords': up to 3 words (4+ letters) from each of the
                first 5 bullets of the IDIOM / Speech Pattern part
        }}
        An empty dict is returned when characters.md does not exist.
    """
    patterns = {}

    voices_file = project_path / "bible" / "characters.md"
    if not voices_file.exists():
        return patterns

    # The header regex below relies on a literal em dash, so the file must be
    # decoded as UTF-8 rather than with the platform's locale encoding.
    content = voices_file.read_text(encoding='utf-8')

    # Anchor on line starts (not on a preceding newline) so a "## " header on
    # the very first line of the file still opens a section.
    sections = re.split(r'^##\s+', content, flags=re.MULTILINE)

    for section in sections[1:]:  # Skip first (text before the first ##)
        # Get character name from first line
        first_line = section.split('\n')[0]
        name_match = re.match(r'([A-Z][A-Za-z]+)\s*—', first_line)
        if not name_match:
            continue

        char_name = name_match.group(1).upper()
        entry = {
            'signature_phrases': [],
            'forbidden_phrases': [],
            'idiom_keywords': []
        }
        patterns[char_name] = entry

        # Extract quoted samples (signature phrases)
        quotes = re.findall(r'>\s*"([^"]+)"', section)
        entry['signature_phrases'] = [quote.lower() for quote in quotes[:5]]

        # Extract anti-patterns / forbidden
        forbidden_section = re.search(
            r'Anti-Patterns.*?(?=###|\Z)', section, re.DOTALL | re.IGNORECASE
        )
        if forbidden_section:
            # Only accept bullets at the start of a line; an unanchored '-'
            # would also match the hyphen inside the "Anti-Patterns" heading
            # and record the rest of the heading as a forbidden phrase.
            forbidden = re.findall(
                r'^[ \t]*-[ \t]*([^\n]+)',
                forbidden_section.group(0),
                re.MULTILINE,
            )
            entry['forbidden_phrases'] = [item.lower() for item in forbidden[:5]]

        # Extract idiom keywords from speech pattern descriptions
        idiom_section = re.search(
            r'(?:IDIOM|Speech Pattern).*?(?=\*\*|\Z)', section, re.DOTALL | re.IGNORECASE
        )
        if idiom_section:
            # NOTE(review): this matches any '-' in the span, including
            # hyphens inside prose; kept as-is pending confirmation of the
            # characters.md format.
            keywords = re.findall(r'-\s*([^\n]+)', idiom_section.group(0))
            for kw in keywords[:5]:
                # Extract key terms
                words = re.findall(r'\b[a-z]{4,}\b', kw.lower())
                entry['idiom_keywords'].extend(words[:3])

    return patterns


def check_contamination(dialogue_by_char, char_patterns):
    """
    Check if characters are using each other's signature patterns.

    For every ordered pair (char_a, char_b) of characters that have patterns,
    char_a's dialogue is searched for:
      - the first three words of any of char_b's signature phrases, and
      - char_b's idiom keywords (flagged when 3 or more distinct keywords
        occur as substrings of char_a's dialogue).

    Args:
        dialogue_by_char: {character_name: [dialogue lines]}.
        char_patterns: {character_name: {'signature_phrases': [...],
            'idiom_keywords': [...], ...}}; phrases and keywords are expected
            to be lower-case already.

    Returns:
        list of (char_a, char_b, evidence) tuples, where char_a is the
        speaker and char_b the owner of the borrowed pattern.
    """
    contaminations = []

    char_names = list(char_patterns)

    for char_a in char_names:
        if char_a not in dialogue_by_char:
            continue

        # Collapse whitespace runs: phrase keys below are rebuilt with single
        # spaces, so the searched text has to be normalised the same way.
        char_a_text = ' '.join(' '.join(dialogue_by_char[char_a]).lower().split())

        for char_b in char_names:
            if char_a == char_b:
                continue

            other = char_patterns[char_b]

            # Check if char_a is using char_b's signature phrases
            for phrase in other.get('signature_phrases', []):
                phrase_words = phrase.split()
                if len(phrase_words) < 3:
                    continue
                # Partial match: the opening three-word sequence is enough
                key_sequence = ' '.join(phrase_words[:3])
                if key_sequence in char_a_text:
                    contaminations.append((
                        char_a,
                        char_b,
                        f"Using {char_b}'s phrase pattern: '{key_sequence}...'"
                    ))

            # Check if char_a is using char_b's idiom keywords frequently.
            # Count each keyword once: the keyword list may repeat a word, and
            # repeats must not push a single shared word over the threshold.
            distinct_keywords = set(other.get('idiom_keywords', []))
            keyword_matches = sum(
                1 for keyword in distinct_keywords if keyword in char_a_text
            )

            if keyword_matches >= 3:
                contaminations.append((
                    char_a,
                    char_b,
                    f"Using {keyword_matches} of {char_b}'s idiom keywords"
                ))

    return contaminations


def check_generic_drift(dialogue_by_char):
    """
    Detect pairs of characters whose vocabularies have become nearly identical.

    Each character's lines are joined and lower-cased, and the set of words
    with four or more letters is taken as that character's vocabulary. Two
    characters are compared only when both have at least 50 characters of
    text and a non-empty vocabulary. The score is shared words divided by the
    combined vocabulary size; pairs scoring above 0.7 are reported.

    Returns list of (char_a, char_b, similarity_score) tuples.
    """
    # Build one (name, text length, vocabulary) profile per character up front.
    profiles = []
    for name, spoken in dialogue_by_char.items():
        joined = ' '.join(spoken).lower()
        vocabulary = set(re.findall(r'\b[a-z]{4,}\b', joined))
        profiles.append((name, len(joined), vocabulary))

    flagged = []

    for position, (first, first_len, first_vocab) in enumerate(profiles):
        for second, second_len, second_vocab in profiles[position + 1:]:
            # Too little text on either side to compare meaningfully
            if min(first_len, second_len) < 50:
                continue
            if not (first_vocab and second_vocab):
                continue

            combined = len(first_vocab | second_vocab)
            shared = len(first_vocab & second_vocab)
            score = shared / combined if combined > 0 else 0

            # Flag if similarity is very high (>70%)
            if score > 0.7:
                flagged.append((first, second, score))

    return flagged


def check_regression(current_dialogue, previous_dialogue, char_patterns):
    """
    Check if any character's voice metrics changed dramatically from last batch.

    Only characters present in both batches with at least 3 lines in each are
    compared. Two metrics are checked:
      - line_length: average words per line changed by more than 40%
        relative to the previous batch.
      - vocabulary: the number of words (4+ letters) not seen in the previous
        batch exceeds 50% of the previous batch's vocabulary size.

    Args:
        current_dialogue: {character_name: [dialogue lines]} for this batch.
        previous_dialogue: same mapping for the preceding batch.
        char_patterns: accepted for interface compatibility; not used here.

    Returns:
        list of (character, metric, change_description) tuples, ordered by
        character name so repeated runs produce the same report.
    """
    regressions = []

    # A raw set intersection of strings iterates in an order that varies
    # between interpreter runs; sort it to keep the output reproducible.
    shared_chars = sorted(set(current_dialogue) & set(previous_dialogue))

    for char in shared_chars:
        current_lines = current_dialogue[char]
        prev_lines = previous_dialogue[char]

        if len(current_lines) < 3 or len(prev_lines) < 3:
            continue

        # Check average line length change
        curr_avg = sum(len(line.split()) for line in current_lines) / len(current_lines)
        prev_avg = sum(len(line.split()) for line in prev_lines) / len(prev_lines)

        if prev_avg > 0:
            length_change = abs(curr_avg - prev_avg) / prev_avg
            if length_change > 0.4:  # 40% change in line length
                direction = "longer" if curr_avg > prev_avg else "shorter"
                regressions.append((
                    char,
                    "line_length",
                    f"Lines suddenly {direction} ({prev_avg:.1f} → {curr_avg:.1f} words)"
                ))

        # Check vocabulary shift (sudden use of new words)
        curr_words = set(re.findall(r'\b[a-z]{4,}\b', ' '.join(current_lines).lower()))
        prev_words = set(re.findall(r'\b[a-z]{4,}\b', ' '.join(prev_lines).lower()))

        if prev_words:
            new_word_ratio = len(curr_words - prev_words) / len(prev_words)
            if new_word_ratio > 0.5:  # 50% new vocabulary
                regressions.append((
                    char,
                    "vocabulary",
                    f"Sudden vocabulary shift ({new_word_ratio:.0%} new words)"
                ))

    return regressions


def main():
    """
    Command-line entry point: run the voice checks for one batch and exit.

    Reads <project_path> and <batch_number> from sys.argv, prints a report to
    stdout and terminates the process via sys.exit with:
      0 - no issues, batch is not a check batch, or no dialogue was found
      1 - at least one contamination / generic drift / regression issue
      2 - usage error, non-integer batch number, or missing project path
    """
    if len(sys.argv) < 3:
        print("Usage: python3 baseline_comparison.py <project_path> <batch_number>")
        print("Example: python3 baseline_comparison.py ./leviathan 6")
        sys.exit(2)

    project_path = Path(sys.argv[1]).resolve()

    # An uncaught ValueError would terminate with status 1, which callers
    # read as "issues detected"; report bad input with the error status.
    try:
        batch_num = int(sys.argv[2])
    except ValueError:
        print(f"Error: Batch number must be an integer: {sys.argv[2]}")
        sys.exit(2)

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}")
        sys.exit(2)

    # Only run at specific batches — VOICE_CHECK_BATCHES (not in CONSTANTS.md; defined here)
    voice_check_batches = (3, 6, 9, 12)
    if batch_num not in voice_check_batches:
        print(f"Voice contamination check only runs at batches 3, 6, 9, 12. Current: {batch_num}")
        sys.exit(0)

    separator = '=' * 60

    print(f"\n{separator}")
    print("VOICE CONTAMINATION CHECK")
    print(separator)
    print(f"Project: {project_path.name}")
    print(f"Batch: {batch_num}")

    # Load character patterns
    char_patterns = load_character_patterns(project_path)
    if not char_patterns:
        print("\nWARNING: No character patterns found in characters.md")
        print("Skipping pattern-based checks.")

    # Extract current batch dialogue
    print("\n--- Extracting Dialogue ---")
    current_dialogue = extract_batch_dialogue(project_path, batch_num)

    if not current_dialogue:
        print("WARNING: No dialogue found in current batch.")
        sys.exit(0)

    for char, lines in current_dialogue.items():
        print(f"  {char}: {len(lines)} lines")

    # Extract previous batch for regression check
    previous_dialogue = {}
    if batch_num > 1:
        previous_dialogue = extract_batch_dialogue(project_path, batch_num - 1)

    all_issues = []

    # Check contamination (needs patterns from characters.md)
    if char_patterns:
        print("\n--- Checking Contamination ---")
        contaminations = check_contamination(current_dialogue, char_patterns)
        for char_a, char_b, evidence in contaminations:
            issue = f"CONTAMINATION: {char_a} → {char_b}: {evidence}"
            all_issues.append(issue)
            print(f"  {issue}")

        if not contaminations:
            print("  No contamination detected")

    # Check generic drift
    print("\n--- Checking Generic Drift ---")
    similarities = check_generic_drift(current_dialogue)
    for char_a, char_b, sim in similarities:
        issue = f"GENERIC: {char_a} and {char_b} are {sim:.0%} similar (voices blurring)"
        all_issues.append(issue)
        print(f"  {issue}")

    if not similarities:
        print("  Voices remain distinct")

    # Check regression (only possible when the previous batch has dialogue)
    if previous_dialogue:
        print("\n--- Checking Regression ---")
        regressions = check_regression(current_dialogue, previous_dialogue, char_patterns)
        for char, metric, desc in regressions:
            issue = f"REGRESSION: {char} ({metric}): {desc}"
            all_issues.append(issue)
            print(f"  {issue}")

        if not regressions:
            print("  No sudden voice changes detected")

    # Report results
    print(f"\n{separator}")

    if all_issues:
        print(f"RESULT: ISSUES DETECTED ({len(all_issues)})")
        print("\nRECOMMENDED ACTION:")
        print("  Run qualitative voice review:")
        print(f"  /dramatic-qc {project_path.name} --mode post --batch {batch_num} --lens voice")
        print("\nThis is a SOFT GATE - generation can continue, but review is recommended.")
        print(f"{separator}\n")
        sys.exit(1)
    else:
        print("RESULT: NO ISSUES DETECTED")
        print("Character voices remain distinct and consistent.")
        print(f"{separator}\n")
        sys.exit(0)


# Script entry point: run the check only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
