#!/usr/bin/env python3
"""
Kill Box Micro Format Validator

Validates episodes against Kill Box Micro (30-second) format constraints.
Same action/cliffhanger philosophy as Kill Box, compressed timeline.

Interface:
    validate_episode(episode_path, constants=None) -> dict
    validate_batch(episode_paths, constants=None) -> dict
"""

import re
from pathlib import Path


# =============================================================================
# CONSTANTS LOADING
# =============================================================================

def _parse_constants_from_md(filepath: Path) -> dict:
    """Parse CONSTANTS.md and extract key values."""
    content = filepath.read_text()
    constants = {}

    table_pattern = r'\|\s*`([A-Z_]+)`\s*\|\s*([^|]+)\s*\|'
    for match in re.finditer(table_pattern, content):
        name = match.group(1).strip()
        value_str = match.group(2).strip()

        # Parse value
        if value_str.endswith('%'):
            try:
                constants[name] = float(value_str.rstrip('%'))
            except ValueError:
                constants[name] = value_str
        elif value_str.lower() in ('true', 'false'):
            constants[name] = value_str.lower() == 'true'
        else:
            try:
                constants[name] = int(value_str)
            except ValueError:
                try:
                    constants[name] = float(value_str)
                except ValueError:
                    constants[name] = value_str

    # Parse TOTAL_WORDS range (e.g. "120-180")
    tw = constants.get('TOTAL_WORDS', '')
    if isinstance(tw, str) and '-' in tw:
        parts = tw.split('-')
        try:
            constants['TOTAL_MIN_WORDS'] = int(parts[0].strip())
            constants['TOTAL_MAX_WORDS'] = int(parts[1].strip())
        except ValueError:
            pass

    # Parse SHOT_RANGE (e.g. "3-10")
    sr = constants.get('SHOT_RANGE', '')
    if isinstance(sr, str) and '-' in sr:
        parts = sr.split('-')
        try:
            constants['SHOT_RANGE_MIN'] = int(parts[0].strip())
            constants['SHOT_RANGE_MAX'] = int(parts[1].strip())
        except ValueError:
            pass

    return constants


def _load_default_constants() -> dict:
    """Load constants from this format's CONSTANTS.md."""
    constants_path = Path(__file__).parent / 'CONSTANTS.md'
    if constants_path.exists():
        return _parse_constants_from_md(constants_path)
    # Hardcoded fallback
    return {
        'TOTAL_MIN_WORDS': 120,
        'TOTAL_MAX_WORDS': 180,
        'SPOKEN_WORDS_MAX': 40,
        'DIALOGUE_MAX_PERCENT': 25,
        'MAX_EXCHANGES': 3,
    }


# =============================================================================
# VALID VALUES
# =============================================================================

# Beat sections every episode must contain.
REQUIRED_BEATS = {
    'CONSEQUENCE',
    'PIVOT',
    'FREEZE',
    'VOTE',
}

# Accepted values of the "Rhythm:" tag, upper-cased.
VALID_RHYTHM_TAGS = {
    'FLUID',
    'FRENETIC',
    'MEASURED',
}

# Accepted values of the cliffhanger-type tag, upper-cased. "CLOCK" and
# "TICKING CLOCK" are both listed so that either spelling is accepted.
VALID_CLIFFHANGER_TYPES = {
    'CLOCK',
    'DILEMMA',
    'REVEAL',
    'REVERSAL',
    'TICKING CLOCK',
}


# =============================================================================
# PARSING HELPERS
# =============================================================================

def _count_words(text: str) -> int:
    """Count words in text."""
    return len(text.split())


def _is_character_cue(line: str) -> bool:
    """Detect if a line is a character cue (speaker name) in Fountain format."""
    stripped = line.strip()
    if not stripped:
        return False
    if not stripped.isupper():
        return False
    if len(stripped) > 30:
        return False
    if stripped.startswith('.') or ':' in stripped:
        return False
    skip_words = [
        'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
        'INSERT', 'CONTINUOUS', 'LATER', 'PULL BACK', 'CLICK',
        'CONSEQUENCE', 'PIVOT', 'FREEZE', 'VOTE',
    ]
    if stripped in skip_words or stripped.startswith('PULL'):
        return False
    if stripped.startswith('INT.') or stripped.startswith('EXT.'):
        return False
    return True


def _is_parenthetical(line: str) -> bool:
    """Detect parenthetical direction within dialogue."""
    stripped = line.strip()
    return stripped.startswith('(') and stripped.endswith(')')


def _count_spoken_words(text: str) -> int:
    """Count words in dialogue blocks (spoken words only).

    A dialogue block opens at a line accepted by ``_is_character_cue`` and
    closes at the next blank line. Inside a block, parenthetical lines are
    skipped and every other line is counted as speech.

    The in-block check comes first, so an all-caps line of speech (e.g. a
    shouted "GET DOWN!") is counted as dialogue rather than being mistaken
    for a new cue. In Fountain, dialogue continues until a blank line.
    """
    spoken_words = 0
    in_dialogue = False

    for line in text.split('\n'):
        stripped = line.strip()

        if in_dialogue:
            if not stripped:
                in_dialogue = False  # a blank line ends the dialogue block
            elif not _is_parenthetical(stripped):
                spoken_words += _count_words(stripped)
        elif _is_character_cue(stripped):
            in_dialogue = True

    return spoken_words


def _count_exchanges(text: str) -> int:
    """Count dialogue exchanges (each character cue = 1).

    Only lines outside a dialogue block can be cues. A block opens at a line
    accepted by ``_is_character_cue`` and closes at the next blank line. An
    all-caps line of speech directly under a cue is therefore not counted as
    a further exchange.
    """
    count = 0
    in_dialogue = False
    for line in text.split('\n'):
        stripped = line.strip()
        if in_dialogue:
            # Stay in the block until a blank line closes it.
            in_dialogue = bool(stripped)
        elif _is_character_cue(stripped):
            count += 1
            in_dialogue = True
    return count


# =============================================================================
# PUBLIC API
# =============================================================================

def validate_episode(episode_path: str, constants: dict = None) -> dict:
    """
    Validate a single Kill Box Micro episode.

    Checks (each failure appends a message to ``errors``):
      * every beat in REQUIRED_BEATS appears on a Markdown heading line
        (1-6 ``#``), optionally with a bracketed timestamp before the name,
        e.g. ``## [0:00-0:08] CONSEQUENCE``;
      * the total word count is within TOTAL_MIN_WORDS..TOTAL_MAX_WORDS;
      * spoken (dialogue) words do not exceed SPOKEN_WORDS_MAX;
      * character-cue exchanges do not exceed MAX_EXCHANGES;
      * a ``Rhythm:`` tag is present and its value is in VALID_RHYTHM_TAGS;
      * a ``Cliffhanger:`` / ``Cliffhanger Type:`` tag is present with a
        recognised value ("Ticking Clock" is accepted as "Clock").
    Markdown emphasis around tag labels or values (``**Rhythm:** Frenetic``)
    is tolerated.

    Args:
        episode_path: Path to the episode file.
        constants: Optional dict of constants. If None, loads from CONSTANTS.md.

    Returns:
        {valid: bool, errors: [], warnings: [], metrics: {}}
        A missing, non-regular or unreadable file yields ``valid: False``
        with a single error instead of raising. One bad path therefore
        cannot abort validate_batch.
    """
    filepath = Path(episode_path)
    errors = []
    warnings = []
    metrics = {}

    if constants is None:
        constants = _load_default_constants()

    total_min = constants.get('TOTAL_MIN_WORDS', 120)
    total_max = constants.get('TOTAL_MAX_WORDS', 180)
    spoken_max = constants.get('SPOKEN_WORDS_MAX', 40)
    max_exchanges = constants.get('MAX_EXCHANGES', 3)

    # is_file() rather than exists(): a directory "exists", but read_text()
    # on it raises instead of returning content.
    if not filepath.is_file():
        return {
            'valid': False,
            'errors': [f"File not found: {filepath}"],
            'warnings': [],
            'metrics': {},
        }

    try:
        content = filepath.read_text(encoding='utf-8')
    except (OSError, UnicodeDecodeError) as exc:
        # Report rather than raise, so a single bad file (permissions,
        # non-UTF-8 bytes) does not abort a whole batch.
        return {
            'valid': False,
            'errors': [f"Could not read file: {filepath} ({exc})"],
            'warnings': [],
            'metrics': {},
        }

    # -----------------------------------------------------------------
    # Beat sections: CONSEQUENCE, PIVOT, FREEZE, VOTE
    # -----------------------------------------------------------------
    beats_found = set()
    for beat in REQUIRED_BEATS:
        # Heading line: optional indent, 1-6 '#', an optional bracketed
        # timestamp with any content (e.g. "[0:00-0:08]", "[0–8s]",
        # "[0.0-7.5s]"), then the beat name. Anchoring to the start of a
        # line stops a mid-sentence "#vote" from counting as the VOTE beat.
        pattern = rf'^[ \t]*#{{1,6}}[ \t]*(?:\[[^\]\n]*\][ \t]*)?{re.escape(beat)}'
        if re.search(pattern, content, re.IGNORECASE | re.MULTILINE):
            beats_found.add(beat)

    metrics['beats_found'] = sorted(beats_found)
    missing_beats = REQUIRED_BEATS - beats_found
    if missing_beats:
        errors.append(f"Missing required beats: {', '.join(sorted(missing_beats))}")

    # -----------------------------------------------------------------
    # Word counts: TOTAL_MIN_WORDS..TOTAL_MAX_WORDS (whole file)
    # -----------------------------------------------------------------
    total_words = _count_words(content)
    metrics['total_word_count'] = total_words

    if total_words < total_min:
        errors.append(f"Total word count too LOW: {total_words} (min: {total_min})")
    elif total_words > total_max:
        errors.append(f"Total word count too HIGH: {total_words} (max: {total_max})")

    # -----------------------------------------------------------------
    # Spoken words: <= SPOKEN_WORDS_MAX
    # -----------------------------------------------------------------
    spoken_words = _count_spoken_words(content)
    metrics['spoken_word_count'] = spoken_words

    if spoken_words > spoken_max:
        errors.append(f"Spoken words too HIGH: {spoken_words} (max: {spoken_max})")

    # -----------------------------------------------------------------
    # Exchange count: <= MAX_EXCHANGES
    # -----------------------------------------------------------------
    exchanges = _count_exchanges(content)
    metrics['exchange_count'] = exchanges

    if exchanges > max_exchanges:
        errors.append(f"Exchanges too HIGH: {exchanges} (max: {max_exchanges})")

    # -----------------------------------------------------------------
    # Rhythm tag: Frenetic, Measured, Fluid
    # -----------------------------------------------------------------
    # [*_]* allows Markdown emphasis such as "**Rhythm:** Frenetic" or
    # "*Rhythm*: _Frenetic_". \w also matches '_', so a trailing '_' from
    # emphasis is stripped from the captured value.
    rhythm_match = re.search(r'Rhythm[*_]*\s*:[*_\s]*(\w+)', content, re.IGNORECASE)
    if rhythm_match:
        rhythm_tag = rhythm_match.group(1).strip('_').upper()
        metrics['rhythm_tag'] = rhythm_tag
        if rhythm_tag not in VALID_RHYTHM_TAGS:
            errors.append(
                f"Invalid rhythm tag: {rhythm_tag} "
                f"(must be one of {', '.join(sorted(VALID_RHYTHM_TAGS))})"
            )
    else:
        errors.append(
            f"Missing rhythm tag (must be one of: {', '.join(sorted(VALID_RHYTHM_TAGS))})"
        )

    # -----------------------------------------------------------------
    # Cliffhanger type: Reveal, Reversal, Clock, Dilemma
    # -----------------------------------------------------------------
    cliff_match = re.search(
        r'Cliffhanger[*_]*\s*(?:Type)?[*_]*\s*:[*_\s]*(.+?)(?:\n|$)',
        content, re.IGNORECASE
    )
    if cliff_match:
        # Drop emphasis markers left around the value, e.g. "Reveal**".
        cliff_type = cliff_match.group(1).strip().strip('*_').strip().upper()
        metrics['cliffhanger_type'] = cliff_type
        # Normalize: "TICKING CLOCK" -> "CLOCK" for comparison
        cliff_normalized = cliff_type.replace('TICKING ', '')
        if cliff_normalized not in VALID_CLIFFHANGER_TYPES and cliff_type not in VALID_CLIFFHANGER_TYPES:
            errors.append(
                f"Invalid cliffhanger type: {cliff_type} "
                f"(must be one of: Reveal, Reversal, Clock, Dilemma)"
            )
    else:
        errors.append(
            "Missing cliffhanger type (must be one of: Reveal, Reversal, Clock, Dilemma)"
        )

    return {
        'valid': len(errors) == 0,
        'errors': errors,
        'warnings': warnings,
        'metrics': metrics,
    }


def validate_batch(episode_paths: list, constants: dict = None) -> dict:
    """
    Validate several Kill Box Micro episodes against one set of constants.

    Args:
        episode_paths: Iterable of episode file paths (str or Path).
        constants: Optional dict of constants shared by every episode. If
            None, it is loaded once from CONSTANTS.md.

    Returns:
        {valid: bool, episode_results: [...]}. ``valid`` is True when no
        episode failed, which includes an empty batch. Each entry is
        validate_episode's result dict extended with a ``file`` key holding
        the path as a string.
    """
    shared_constants = _load_default_constants() if constants is None else constants

    def _check(path) -> dict:
        # Validate one path and tag the result with the path it came from.
        outcome = validate_episode(str(path), constants=shared_constants)
        outcome['file'] = str(path)
        return outcome

    episode_results = [_check(path) for path in episode_paths]
    return {
        'valid': all(outcome['valid'] for outcome in episode_results),
        'episode_results': episode_results,
    }
