#!/usr/bin/env python3
"""
engine_constants.py - Shared Constants Parser

Reads constants from the engine-root CONSTANTS.md (referred to as
/CONSTANTS.md) so that all Python scripts share a single source of truth.

Usage:
    from engine_constants import CONSTANTS

    word_min = CONSTANTS['WORD_COUNT_MIN']
    word_max = CONSTANTS['WORD_COUNT_MAX']
    dialogue_max = CONSTANTS['DIALOGUE_MAX_PERCENT']

Or import specific values:
    from engine_constants import (
        WORD_COUNT_MIN, WORD_COUNT_MAX, DIALOGUE_MAX_PERCENT,
        MAX_EXCHANGES, MAX_CONSECUTIVE_SAME_TYPE
    )

LLM helpers:
    from engine_constants import (
        ANTHROPIC_HAIKU, ANTHROPIC_SONNET, ANTHROPIC_OPUS,
        get_anthropic_client, parse_llm_field,
    )
"""

import logging
import os
import re
from pathlib import Path
from typing import Dict, Any, Optional

logger = logging.getLogger(__name__)

# Find the CONSTANTS.md file relative to this script
def _find_constants_file() -> Path:
    """Locate CONSTANTS.md in the engine directory."""
    script_dir = Path(__file__).parent

    # Try CONSTANTS.md (if we're in tools/)
    if script_dir.name == 'tools':
        constants_path = script_dir.parent / 'CONSTANTS.md'
        if constants_path.exists():
            return constants_path

    # Try from current directory
    cwd = Path.cwd()
    candidates = [
        cwd / 'CONSTANTS.md',
        cwd.parent / 'CONSTANTS.md',
    ]

    for candidate in candidates:
        if candidate.exists():
            return candidate

    raise FileNotFoundError(
        "Could not find CONSTANTS.md. Expected at /CONSTANTS.md"
    )


# "N-M" range at the start of a value, hyphen or en dash, optional trailing
# text (e.g. "450-500", "450 – 500", "450-500 words").
_RANGE_RE = re.compile(r'^\s*(\d+)\s*[-\u2013]\s*(\d+)')


def _split_range(value: Any) -> Optional[tuple]:
    """Return ``(min, max)`` ints for a range string like "450-500", else None."""
    if not isinstance(value, str):
        return None
    match = _RANGE_RE.match(value)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def _parse_constants_md(filepath: Path) -> Dict[str, Any]:
    """
    Parse the root CONSTANTS.md and extract key values.

    Looks for table rows like:
    - | `WORD_COUNT_DEFAULT` | 450-500 |
    - | `DIALOGUE_MAX_PERCENT` | 40% |

    and then derives these extra entries:
    - WORD_COUNT_MIN / WORD_COUNT_MAX from a `WORD_COUNT` range, or (if that
      constant is absent) from a "Word count ... N-M" phrase in the prose.
      Explicit WORD_COUNT_MIN / WORD_COUNT_MAX table values are never
      overwritten by the prose fallback.
    - VALIDATION_WORD_MIN / VALIDATION_WORD_MAX from `VALIDATION_WORD_DEFAULT`.
    - MAX_CONSECUTIVE_SAME_TYPE (falls back to 3 with a warning).
    - HOOK_SILENT_VALIDATION_MIN/MAX and CLIFFHANGER_MIDACTION_VALIDATION_MIN/MAX
      from the Validation Range column of the Pattern Distribution table.
    - TREATMENT_WORD_COUNT_RANGES: {beat type: (min, max)}.
    - KEY_EPISODE_WORD_COUNTS: {episode number: (min, max)}.

    Args:
        filepath: Path to CONSTANTS.md (read as UTF-8).

    Returns:
        Dict mapping constant names to parsed values.
    """
    # Explicit UTF-8: markdown often contains non-ASCII punctuation, and the
    # platform default encoding (e.g. cp1252 on Windows) can fail to decode it.
    content = filepath.read_text(encoding='utf-8')
    constants: Dict[str, Any] = {}

    # Pattern for table rows with constant names.
    # Matches: | `CONSTANT_NAME` | value | ... |
    # Names may contain digits after the first character, the same rule the
    # format-specific parser uses.
    table_pattern = r'\|\s*`([A-Z_][A-Z0-9_]*)`\s*\|\s*([^|]+)\s*\|'

    for match in re.finditer(table_pattern, content):
        name = match.group(1).strip()
        value_str = match.group(2).strip()
        constants[name] = _parse_value(value_str)

    # Also extract some derived/computed values that aren't in tables
    # but are useful for scripts.

    # Word count range (450-500)
    if 'WORD_COUNT' in constants:
        word_bounds = _split_range(constants['WORD_COUNT'])
        if word_bounds is not None:
            constants['WORD_COUNT_MIN'], constants['WORD_COUNT_MAX'] = word_bounds
    elif 'WORD_COUNT_MIN' not in constants or 'WORD_COUNT_MAX' not in constants:
        # Prose fallback; setdefault keeps any bound that was given explicitly.
        wc_match = re.search(r'Word count.*?(\d+)-(\d+)', content, re.IGNORECASE)
        if wc_match:
            logger.warning(
                "WORD_COUNT not found as table constant, using fallback regex parse: %s-%s",
                wc_match.group(1),
                wc_match.group(2),
            )
            constants.setdefault('WORD_COUNT_MIN', int(wc_match.group(1)))
            constants.setdefault('WORD_COUNT_MAX', int(wc_match.group(2)))

    # Parse validation range from VALIDATION_WORD_DEFAULT (e.g., "450-500")
    validation_bounds = _split_range(constants.get('VALIDATION_WORD_DEFAULT'))
    if validation_bounds is not None:
        constants['VALIDATION_WORD_MIN'], constants['VALIDATION_WORD_MAX'] = validation_bounds

    # Max consecutive: first "MAX_CONSECUTIVE_SAME_TYPE ... | <digits>" on a line.
    consec_match = re.search(r'MAX_CONSECUTIVE_SAME_TYPE.*?\|\s*(\d+)', content)
    if consec_match:
        constants['MAX_CONSECUTIVE_SAME_TYPE'] = int(consec_match.group(1))
    else:
        logger.warning(
            "Could not parse MAX_CONSECUTIVE_SAME_TYPE from CONSTANTS.md, using fallback value 3"
        )
        constants['MAX_CONSECUTIVE_SAME_TYPE'] = 3

    # Parse validation ranges from Pattern Distribution table
    # Format: | `HOOK_SILENT_PERCENT` | 80% | 70-85% | ... |
    # We need the third column (Validation Range) for specific constants
    validation_range_pattern = r'\|\s*`(HOOK_SILENT_PERCENT|CLIFFHANGER_MIDACTION_PERCENT)`\s*\|\s*[^|]+\|\s*(\d+)-(\d+)%\s*\|'
    for match in re.finditer(validation_range_pattern, content):
        const_name = match.group(1).strip()
        val_min = int(match.group(2))
        val_max = int(match.group(3))
        if const_name == 'HOOK_SILENT_PERCENT':
            constants['HOOK_SILENT_VALIDATION_MIN'] = val_min
            constants['HOOK_SILENT_VALIDATION_MAX'] = val_max
        elif const_name == 'CLIFFHANGER_MIDACTION_PERCENT':
            constants['CLIFFHANGER_MIDACTION_VALIDATION_MIN'] = val_min
            constants['CLIFFHANGER_MIDACTION_VALIDATION_MAX'] = val_max

    # Parse Treatment Word Counts by Beat Type table
    # Format: | BEAT_TYPE | 40-55 | notes |
    treatment_word_counts = {}
    beat_type_pattern = r'\|\s*(SETUP|COMPLICATION|CATALYST|LOCK-IN|COLLISION|CRISIS|REVELATION|CLIMAX|RESOLUTION)\s*\|\s*(\d+)-(\d+)\s*\|'
    for match in re.finditer(beat_type_pattern, content):
        beat_type = match.group(1).strip()
        treatment_word_counts[beat_type] = (int(match.group(2)), int(match.group(3)))
    constants['TREATMENT_WORD_COUNT_RANGES'] = treatment_word_counts

    # Parse Key Episode Treatment Words table. The section runs from its
    # heading to the next '---' rule, '##' heading, or end of file.
    # Format: | 1 | 80-100 | reason |
    key_episode_words = {}
    key_ep_section = re.search(r'Key Episode Treatment Words.*?(?=\n---|\n##|\Z)', content, re.DOTALL)
    if key_ep_section:
        key_ep_pattern = r'\|\s*(\d+)\s*\|\s*(\d+)-(\d+)\s*\|'
        for match in re.finditer(key_ep_pattern, key_ep_section.group(0)):
            ep_num = int(match.group(1))
            key_episode_words[ep_num] = (int(match.group(2)), int(match.group(3)))
    constants['KEY_EPISODE_WORD_COUNTS'] = key_episode_words

    return constants


def _parse_value(value_str: str) -> Any:
    """Parse a value string into appropriate Python type."""
    value_str = value_str.strip()

    # Remove trailing notes like "(see CONSTANTS.md)"
    value_str = re.sub(r'\s*\(.*\)\s*$', '', value_str)

    # Percentage (40%)
    if value_str.endswith('%'):
        try:
            return int(value_str[:-1])
        except ValueError:
            return value_str

    # Range (e.g., 450-500)
    if re.match(r'^\d+-\d+$', value_str):
        return value_str  # Keep as string for ranges

    # Integer
    if re.match(r'^\d+$', value_str):
        return int(value_str)

    # Float
    if re.match(r'^\d+\.\d+$', value_str):
        return float(value_str)

    # Everything else stays as string
    return value_str


# =============================================================================
# FORMAT-AWARE CONSTANTS LOADER
# =============================================================================

# Recoil engine root: the directory two levels above this file (the parent of
# tools/ when this module lives in tools/), resolved to an absolute path.
# load_format_constants() looks for formats/<name>/CONSTANTS.md and the root
# CONSTANTS.md under this directory.
_ENGINE_ROOT: Path = Path(__file__).parent.parent.resolve()

# Parsed format constants keyed by format name; filled by load_format_constants().
# Nothing in this module clears it, so CONSTANTS.md edits need a fresh process.
_format_constants_cache: Dict[str, Dict[str, Any]] = {}


def _parse_format_value(value_str: str) -> Any:
    """
    Parse a value string from a format CONSTANTS.md into appropriate Python type.

    Handles:
    - Quoted strings → strip quotes
    - Parenthetical notes → strip them
    - "true"/"false" → bool
    - Percentages → int (without % sign)
    - Integers → int
    - Floats → float
    - Ranges like "450-500" → kept as string
    - Everything else → string
    """
    value_str = value_str.strip()

    # Strip surrounding quotes
    if (value_str.startswith('"') and value_str.endswith('"')) or \
       (value_str.startswith("'") and value_str.endswith("'")):
        value_str = value_str[1:-1].strip()

    # Remove trailing parenthetical notes like "(see FORMAT.md)"
    value_str = re.sub(r'\s*\([^)]*\)\s*$', '', value_str).strip()

    # Booleans
    if value_str.lower() == 'true':
        return True
    if value_str.lower() == 'false':
        return False

    # Percentage (40%)
    if value_str.endswith('%'):
        try:
            return int(value_str[:-1])
        except ValueError:
            return value_str

    # Range (e.g., 450-500) — keep as string
    if re.match(r'^\d+-\d+$', value_str):
        return value_str

    # Integer
    if re.match(r'^\d+$', value_str):
        return int(value_str)

    # Float
    if re.match(r'^\d+\.\d+$', value_str):
        return float(value_str)

    # Everything else stays as string
    return value_str


def _parse_format_constants_md(filepath: Path) -> Dict[str, Any]:
    """
    Parse a format-specific CONSTANTS.md file.

    Looks for markdown table rows matching: | `KEY` | VALUE | ...
    and maps every backtick-wrapped constant name (an uppercase letter or
    underscore followed by uppercase letters, digits, or underscores) to its
    second-column value converted by _parse_format_value. A later row with
    the same name overwrites an earlier one.

    Args:
        filepath: Path to the markdown file (read as UTF-8).

    Returns:
        Dict of constant name -> parsed value.
    """
    # Explicit UTF-8: markdown often contains non-ASCII punctuation, and the
    # platform default encoding (e.g. cp1252 on Windows) can fail to decode it.
    content = filepath.read_text(encoding='utf-8')
    constants: Dict[str, Any] = {}

    # Match table rows: | `CONSTANT_NAME` | value | (optional extra columns)
    table_pattern = r'\|\s*`([A-Z_][A-Z0-9_]*)`\s*\|\s*([^|]+?)\s*\|'

    for match in re.finditer(table_pattern, content):
        name = match.group(1).strip()
        value_str = match.group(2).strip()
        constants[name] = _parse_format_value(value_str)

    return constants


def load_format_constants(format_name: str) -> Dict[str, Any]:
    """
    Return the parsed constants for one format, loading and caching on first use.

    Lookup order under _ENGINE_ROOT:
      1. formats/<format_name>/CONSTANTS.md
      2. CONSTANTS.md at the engine root (backward compatibility)
    When neither file exists, an empty dict is cached and returned.

    Args:
        format_name: Format directory name (e.g., "kill_box", "puzzle_box").

    Returns:
        Dict of parsed constants. The cached dict itself is returned on every
        call, so callers should treat it as read-only.
    """
    cached = _format_constants_cache.get(format_name)
    if cached is not None:
        return cached

    search_order = (
        _ENGINE_ROOT / 'formats' / format_name / 'CONSTANTS.md',
        _ENGINE_ROOT / 'CONSTANTS.md',
    )
    # Lazily parse only the first file that exists.
    loaded: Dict[str, Any] = next(
        (_parse_format_constants_md(path) for path in search_order if path.exists()),
        {},
    )

    _format_constants_cache[format_name] = loaded
    return loaded


# Load constants on module import - FAIL HARD if not found
_CONSTANTS_FILE = _find_constants_file()
CONSTANTS = _parse_constants_md(_CONSTANTS_FILE)


def _required_range(name: str) -> tuple:
    """
    Return CONSTANTS[name] parsed from an "N-M" range string as (min, max) ints.

    Raises KeyError if the constant is missing and ValueError if its value is
    not a range, so a malformed CONSTANTS.md entry fails at import time.
    """
    raw = CONSTANTS[name]
    match = re.match(r'^\s*(\d+)\s*[-\u2013]\s*(\d+)', raw) if isinstance(raw, str) else None
    if match is None:
        raise ValueError(f"Could not parse {name} range: {raw}")
    return int(match.group(1)), int(match.group(2))


# Export commonly used values as module-level constants.
# Required values are read with CONSTANTS[...] so a missing entry fails at
# import; only the values explicitly marked optional below use a default.
WORD_COUNT_MIN = CONSTANTS['WORD_COUNT_MIN']
WORD_COUNT_MAX = CONSTANTS['WORD_COUNT_MAX']
# Optional: default to the episode word-count range.
VALIDATION_WORD_MIN = CONSTANTS.get('VALIDATION_WORD_MIN', WORD_COUNT_MIN)
VALIDATION_WORD_MAX = CONSTANTS.get('VALIDATION_WORD_MAX', WORD_COUNT_MAX)

DIALOGUE_MAX_PERCENT = CONSTANTS['DIALOGUE_MAX_PERCENT']
MAX_EXCHANGES = CONSTANTS['MAX_EXCHANGES']
MAX_ACTION_BLOCK_LINES = CONSTANTS['MAX_ACTION_BLOCK_LINES']

# Pattern distribution
HOOK_SILENT_PERCENT = CONSTANTS['HOOK_SILENT_PERCENT']
CLIFFHANGER_MIDACTION_PERCENT = CONSTANTS['CLIFFHANGER_MIDACTION_PERCENT']
MAX_CONSECUTIVE_SAME_TYPE = CONSTANTS['MAX_CONSECUTIVE_SAME_TYPE']

# Pattern validation ranges (parsed from CONSTANTS.md → Pattern Distribution → Validation Range column)
HOOK_SILENT_VALIDATION_MIN = CONSTANTS['HOOK_SILENT_VALIDATION_MIN']
HOOK_SILENT_VALIDATION_MAX = CONSTANTS['HOOK_SILENT_VALIDATION_MAX']
CLIFFHANGER_MIDACTION_VALIDATION_MIN = CONSTANTS['CLIFFHANGER_MIDACTION_VALIDATION_MIN']
CLIFFHANGER_MIDACTION_VALIDATION_MAX = CONSTANTS['CLIFFHANGER_MIDACTION_VALIDATION_MAX']

# Treatment constraints - parsed from "N-M" ranges in CONSTANTS.md
TREATMENT_TOTAL_WORDS_MIN, TREATMENT_TOTAL_WORDS_MAX = _required_range('TREATMENT_TOTAL_WORDS')
TREATMENT_AVG_PER_EPISODE_MIN, TREATMENT_AVG_PER_EPISODE_MAX = _required_range('TREATMENT_AVG_PER_EPISODE')

TREATMENT_BATCH_SIZE = CONSTANTS['TREATMENT_BATCH_SIZE']

# Treatment word count ranges by beat type (from CONSTANTS.md table)
TREATMENT_WORD_COUNT_RANGES = CONSTANTS['TREATMENT_WORD_COUNT_RANGES']

# Key episode word count overrides (from CONSTANTS.md table)
KEY_EPISODE_WORD_COUNTS = CONSTANTS['KEY_EPISODE_WORD_COUNTS']

# Tolerance for treatment word counts (hard-coded; not read from CONSTANTS.md)
TREATMENT_WORD_COUNT_TOLERANCE = 10

# Generation constraints
GENERATION_BATCH_SIZE = CONSTANTS['GENERATION_BATCH_SIZE']
TOTAL_EPISODES = CONSTANTS['TOTAL_EPISODES']
# Ceiling division: a trailing partial batch still counts as a batch
# (e.g. 61 episodes in batches of 5 -> 13 batches, the last holding ep 61,
# matching the batch-13 entries in VOICE_CHECK_BATCHES below).
TOTAL_BATCHES = -(-TOTAL_EPISODES // GENERATION_BATCH_SIZE)

# Pilot window
PILOT_EPISODE_COUNT = CONSTANTS['PILOT_EPISODE_COUNT']
PATTERN_VARIETY_STARTS = CONSTANTS['PATTERN_VARIETY_STARTS']
PAYWALL_EPISODE = CONSTANTS['PAYWALL_EPISODE']

# Structural beats - hard-coded from the Emotional Beat Schedule in
# CONSTANTS.md (not parsed); keep in sync with that document by hand.
PLOT_POINT_1 = 15
MIDPOINT = 30
ALL_IS_LOST = 45
RESOLUTION = 61

# Voice contamination check batches - hard-coded from CONSTANTS.md Voice
# Contamination Checkpoints (not parsed); keep in sync by hand.
# Batches 3, 6, 9, 12, 13 = episodes 11-15, 26-30, 41-45, 56-60, 61
VOICE_CHECK_BATCHES = [3, 6, 9, 12, 13]

# Thread continuity constants - from CONSTANTS.md Thread Continuity section.
# Optional: defaults are used when the table entries are absent.
STALE_THRESHOLD = CONSTANTS.get('STALE_THRESHOLD', 15)
MIN_THREAD_COUNT = CONSTANTS.get('MIN_THREAD_COUNT', 6)

# Goal-backward verification checkpoints (hard-coded)
GOAL_BACKWARD_BATCHES = [3, 6, 9, 12, 13]  # Batches 3 (ep15), 6 (ep30), 9 (ep45), 12 (ep60), 13 (ep61)


# =============================================================================
# SHARED WORD & DIALOGUE COUNTING — Single Source of Truth
# =============================================================================
# ALL scripts MUST import these functions instead of implementing their own.
# Word counting: count ALL words in file (including metadata, headers, etc.)
# Dialogue detection: Fountain-style character cue parser.

# Non-character ALL-CAPS words to exclude from character cue detection.
# Single words are compared with a cue's first word; multi-word entries such
# as 'PULL BACK' are compared with the cue's leading words.
_CHARACTER_CUE_EXCLUDE = frozenset({
    'INT', 'EXT', 'CUT', 'FADE', 'THE', 'AND', 'BUT', 'THEN',
    'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
    'INSERT', 'CONTINUOUS', 'LATER', 'NIGHT', 'DAY', 'MORNING',
    'EVENING', 'PULL BACK', 'CLICK',
})

# Multi-word exclusions pre-split into word tuples: a first-word lookup alone
# can never match them, so they need a leading-words comparison.
_CHARACTER_CUE_EXCLUDE_PHRASES = tuple(
    tuple(entry.split()) for entry in _CHARACTER_CUE_EXCLUDE if ' ' in entry
)


def count_words(text: str) -> int:
    """
    Count ALL words in text. No stripping of metadata or formatting.
    This is the canonical word counting function for the entire engine.

    Counts everything: headers, metadata, Kill Box timestamps, annotations,
    markdown formatting tokens, cliffhanger/hook type labels — everything.

    This is why the word count target is 450-500 (not lower).
    """
    return len(text.split())


def is_character_cue(line: str) -> bool:
    """
    Detect if a line is a character cue (speaker name) in Fountain format.

    A cue is at least two characters of uppercase letters, whitespace,
    hyphens and apostrophes (after removing parenthetical extensions such as
    (V.O.) or (CONT'D)), whose leading word(s) are not in
    _CHARACTER_CUE_EXCLUDE.
    Canonical implementation — all scripts must use this.
    """
    stripped = line.strip()
    if not stripped:
        return False

    # Remove parenthetical (V.O.), (CONT'D), etc.
    base = re.sub(r'\s*\([^)]+\)\s*', '', stripped)
    if not base:
        return False

    # Character cues: all uppercase letters, spaces, hyphens, apostrophes
    if not re.match(r'^[A-Z][A-Z\s\'\-]+$', base):
        return False

    words = base.split()
    if words[0] in _CHARACTER_CUE_EXCLUDE:
        return False
    # Reject lines that begin with a multi-word exclusion ("PULL BACK ...").
    return not any(
        tuple(words[:len(phrase)]) == phrase
        for phrase in _CHARACTER_CUE_EXCLUDE_PHRASES
    )


def is_parenthetical(line: str) -> bool:
    """Return True when the trimmed line is wrapped in parentheses, e.g. "(beat)"."""
    text = line.strip()
    if not text:
        return False
    return text[0] == '(' and text[-1] == ')'


# Scene headings that terminate an open dialogue block.
_SCENE_HEADING_RE = re.compile(r'^(INT\.|EXT\.|INT/EXT\.)')


def _is_dialogue_terminator(stripped: str) -> bool:
    """
    True for a line that ends dialogue without being a blank line: a scene
    heading (any length) or a short ALL-CAPS transition ending in ':'.
    """
    if _SCENE_HEADING_RE.match(stripped):
        return True
    # Length cap only applies to transitions such as "CUT TO:".
    return stripped.isupper() and len(stripped) < 30 and stripped.endswith(':')


def parse_dialogue_blocks(text: str) -> list:
    """
    Parse episode text and return list of (speaker, dialogue_text) tuples.
    Uses Fountain-style character cue detection.
    Canonical implementation — all scripts must use this.

    A block opens at a character cue (see is_character_cue) and collects the
    following stripped lines, joined with newlines. It closes at a blank line,
    a scene heading (INT./EXT./INT/EXT.), a short ALL-CAPS line ending in ':',
    the next character cue, or end of text. Parenthetical lines are skipped,
    and blocks with no dialogue lines are dropped.

    NOTE(review): an ALL-CAPS dialogue line that also passes is_character_cue
    (e.g. "NO WAY") is treated as a new speaker.

    Args:
        text: Full episode file content (including metadata)

    Returns:
        List of (character_name, dialogue_text) tuples
    """
    dialogue_blocks = []
    current_character = None
    current_dialogue = []
    in_dialogue = False

    def _flush() -> None:
        """Emit the open block if it has a speaker and lines, then reset it."""
        nonlocal current_character, current_dialogue
        if current_character and current_dialogue:
            dialogue_blocks.append((current_character, '\n'.join(current_dialogue)))
        current_character = None
        current_dialogue = []

    for line in text.split('\n'):
        stripped = line.strip()

        if is_character_cue(stripped):
            _flush()
            # Speaker name without (V.O.), (CONT'D), ... extensions.
            current_character = re.sub(r'\s*\([^)]+\)\s*', '', stripped).strip()
            in_dialogue = True
            continue

        if not in_dialogue:
            continue

        if not stripped:
            # Empty line ends dialogue
            _flush()
            in_dialogue = False
            continue

        if is_parenthetical(stripped):
            # Skip parentheticals from dialogue word count
            continue

        if _is_dialogue_terminator(stripped):
            # Scene header or transition — end dialogue
            _flush()
            in_dialogue = False
            continue

        current_dialogue.append(stripped)

    # Don't forget trailing dialogue block
    _flush()
    return dialogue_blocks


def count_dialogue_words(dialogue_blocks: list) -> int:
    """
    Total the whitespace-separated words of every dialogue block.
    Canonical implementation; counts the same way as count_words (str.split).

    Args:
        dialogue_blocks: (speaker, dialogue_text) pairs from parse_dialogue_blocks()

    Returns:
        Total dialogue word count (speaker names are not counted)
    """
    total = 0
    for _speaker, spoken in dialogue_blocks:
        total += len(spoken.split())
    return total


def count_exchanges(dialogue_blocks: list) -> int:
    """
    Return the number of dialogue exchanges.
    Canonical implementation.

    Args:
        dialogue_blocks: (speaker, dialogue_text) pairs from parse_dialogue_blocks()

    Returns:
        Number of exchanges — one per block, i.e. one per character cue
        that produced dialogue
    """
    exchange_total = len(dialogue_blocks)
    return exchange_total


def calculate_dialogue_percent(total_words: int, dialogue_words: int) -> float:
    """
    Express dialogue words as a percentage of all words.
    Canonical implementation.

    Args:
        total_words: Result of count_words()
        dialogue_words: Result of count_dialogue_words()

    Returns:
        (dialogue_words / total_words) * 100, or 0.0 when total_words is 0
    """
    if total_words != 0:
        return (dialogue_words / total_words) * 100
    return 0.0


def get_word_count_range() -> tuple:
    """
    Return the episode word-count target.

    Returns:
        (WORD_COUNT_MIN, WORD_COUNT_MAX)
    """
    bounds = (WORD_COUNT_MIN, WORD_COUNT_MAX)
    return bounds


def get_validation_range() -> tuple:
    """
    Return the word-count range used for validation.

    Returns:
        (VALIDATION_WORD_MIN, VALIDATION_WORD_MAX)
    """
    bounds = (VALIDATION_WORD_MIN, VALIDATION_WORD_MAX)
    return bounds


# =============================================================================
# ANTHROPIC MODEL CONSTANTS & LLM HELPERS
# =============================================================================
# Centralized model IDs — change here, not in 5 separate files.

# Model identifiers passed as ``model=`` to the Claude helpers below.
ANTHROPIC_HAIKU: str = "claude-haiku-4-5-20251001"
ANTHROPIC_SONNET: str = "claude-sonnet-4-6"
ANTHROPIC_OPUS: str = "claude-opus-4-6"


# Lazily populated by get_anthropic_client(): the client (or None) plus a
# flag recording that construction has already been attempted.
_cached_anthropic_client: Optional[Any] = None
_anthropic_client_checked: bool = False
# Set by _warn_no_transport_once() after its single warning.
_warned_no_transport: bool = False


def _warn_no_transport_once(lane: str) -> None:
    """
    Log the "LLM gates degraded" warning the first time it is requested.

    ``lane`` names the transport that was unusable; later calls are no-ops
    because the module-level _warned_no_transport flag stays set.
    """
    global _warned_no_transport
    if not _warned_no_transport:
        _warned_no_transport = True
        logger.warning(
            "LLM gates degraded: no usable Claude transport (lane=%s) — narrative LLM gates will skip",
            lane,
        )


def get_anthropic_client():
    """
    Return the shared Anthropic SDK client for LLM-based gates, or None.

    Construction is attempted only on the first call: it needs the
    ``anthropic`` package to import and ANTHROPIC_API_KEY to be set. The
    outcome (client or None) is cached, so later calls are cheap. Whenever
    None is returned, the one-time "no transport" warning is requested.
    """
    global _cached_anthropic_client, _anthropic_client_checked
    if not _anthropic_client_checked:
        _anthropic_client_checked = True
        try:
            import anthropic
            api_key = os.environ.get("ANTHROPIC_API_KEY")
            if api_key:
                _cached_anthropic_client = anthropic.Anthropic(api_key=api_key)
        except ImportError:
            pass  # SDK not installed: leave the cached client as None

    if _cached_anthropic_client is None:
        _warn_no_transport_once("sdk")
    return _cached_anthropic_client


def call_anthropic(client, model: str, prompt: str, max_tokens: int = 200) -> Optional[str]:
    """
    Call Anthropic API and return stripped response text, or None on failure.
    Centralizes the client.messages.create() pattern used across all hooks.

    Args:
        client: Anthropic client (e.g. from get_anthropic_client()), or None.
        model: Model identifier, e.g. ANTHROPIC_HAIKU.
        prompt: Single user-turn prompt text.
        max_tokens: Output token limit passed to messages.create().

    Returns:
        The stripped text of the first response content block that carries
        a ``text`` string (non-text blocks are skipped). Returns None when
        client is None, when no such block exists, or when the call raises.
        Exceptions are logged at WARNING rather than propagated, so gates
        degrade to "skip" instead of crashing.
    """
    if client is None:
        return None
    try:
        resp = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        for block in resp.content or ():
            text = getattr(block, "text", None)
            if isinstance(text, str):
                return text.strip()
        return None
    except Exception as exc:  # best-effort boundary: never crash the caller
        logger.warning("Anthropic call failed (model=%s): %s", model, exc)
        return None


def llm_gate_call(prompt: str, *, model: str, max_tokens: int = 200) -> Optional[str]:
    """
    Text-only Claude seam for narrative LLM gates.

    Routes the prompt through the transport chosen by ``claude_transport()``:
    the Claude CLI when it reports "cli", otherwise the Anthropic SDK client.

    Returns the stripped response text, or None when:
      - the CLI raises ClaudeCliError (one-time degraded warning is logged);
      - no SDK client is available (one-time degraded warning is logged);
      - call_anthropic() itself returns None.

    ``max_tokens`` is forwarded only on the SDK path; the CLI call here does
    not receive it.

    Vision callers must keep using their vision-aware paths.
    """
    from recoil.core.claude_cli import ClaudeCliError, claude_cli_call, claude_transport

    lane = claude_transport()
    if lane != "cli":
        sdk_client = get_anthropic_client()
        if sdk_client is None:
            _warn_no_transport_once(lane)
            return None
        return call_anthropic(sdk_client, model, prompt, max_tokens=max_tokens)

    try:
        return claude_cli_call(prompt, model=model).strip()
    except ClaudeCliError:
        _warn_no_transport_once(lane)
        return None


def parse_llm_field(result: str, field: str, expected: Optional[list] = None) -> Optional[str]:
    """
    Parse a structured field from an LLM response.

    LLM gates use prompts that request output like:
        SHIFTS: YES
        CAUSE: some text here
        VERDICT: GENUINE

    This function extracts the value after the LAST "FIELD:" occurrence, up
    to the end of that line, strips whitespace, and optionally matches it
    against expected values.

    Args:
        result: Full LLM response text (None or "" yields None).
        field: Field name to extract (e.g., "SHIFTS", "VERDICT")
        expected: If provided, return the expected value (upper-cased) that
                  appears as a whole word/phrase in the extracted text
                  (case-insensitive). If several match, the one starting
                  earliest wins; on a tie the longer one wins, so
                  "NOT GENUINE" beats "GENUINE". Substrings inside other
                  words do not count ("NO" does not match "NOT SURE").
                  If nothing matches, returns None.
                  If not provided, returns the raw extracted text.

    Returns:
        The extracted value, a matched expected value, or None if
        the field isn't found or doesn't match any expected value.
    """
    if not result:
        return None
    marker = f"{field}:"
    if marker not in result:
        return None

    # Last occurrence wins (a final answer usually follows any reasoning);
    # keep only the remainder of that line.
    after_field = result.rsplit(marker, 1)[-1].split("\n", 1)[0].strip()

    if expected is None:
        return after_field or None

    after_upper = after_field.upper()
    best_key = None
    best_value = None
    for value in expected:
        candidate = value.upper()
        # Whole-token match: bounded by non-word characters or the string edges.
        hit = re.search(rf"(?<!\w){re.escape(candidate)}(?!\w)", after_upper)
        if hit is None:
            continue
        key = (hit.start(), -len(candidate))
        if best_key is None or key < best_key:
            best_key, best_value = key, candidate

    return best_value


def extract_script_content(content: str) -> str:
    """
    Return the script/fountain portion of an episode file, stripped.
    Canonical implementation — import this instead of defining locally.

    Precedence:
      1. Text inside the first ```fountain fence (up to the next ```).
      2. Text after the first "## SCRIPT", cut at any "## NOTES",
         "## METADATA" or "## ---" marker.
      3. Otherwise the whole file with everything after the last
         "\\n---\\n" separator removed.
    """
    fence = "```fountain"
    if fence in content:
        after_fence = content.partition(fence)[2]
        return after_fence.split("```")[0].strip()

    heading = "## SCRIPT"
    if heading in content:
        body = content.split(heading, 1)[1]
        # Cut at each trailing-section marker; absent markers leave body unchanged.
        for stop in ("## NOTES", "## METADATA", "## ---"):
            body = body.partition(stop)[0]
        return body.strip()

    separator = "\n---\n"
    if separator in content:
        content = content.rpartition(separator)[0]
    return content.strip()


if __name__ == '__main__':
    # Dump the resolved values (handy for checking how CONSTANTS.md was parsed).
    print("Engine Constants (from CONSTANTS.md)")
    print("=" * 50)
    print(f"Source: {_CONSTANTS_FILE}")
    print()

    report = (
        ("Episode Constraints:", (
            f"Word count: {WORD_COUNT_MIN}-{WORD_COUNT_MAX}",
            f"Validation range: {VALIDATION_WORD_MIN}-{VALIDATION_WORD_MAX}",
            f"Dialogue max: {DIALOGUE_MAX_PERCENT}%",
            f"Max exchanges: {MAX_EXCHANGES}",
        )),
        ("Pattern Distribution:", (
            f"Hook silent: {HOOK_SILENT_PERCENT}%",
            f"Cliffhanger mid-action: {CLIFFHANGER_MIDACTION_PERCENT}%",
            f"Max consecutive same type: {MAX_CONSECUTIVE_SAME_TYPE}",
        )),
        ("Batch Structure:", (
            f"Generation batch size: {GENERATION_BATCH_SIZE}",
            f"Treatment batch size: {TREATMENT_BATCH_SIZE}",
            f"Total episodes: {TOTAL_EPISODES}",
            f"Total generation batches: {TOTAL_BATCHES}",
        )),
        ("Treatment Constraints:", (
            f"Total words: {TREATMENT_TOTAL_WORDS_MIN}-{TREATMENT_TOTAL_WORDS_MAX}",
            f"Avg per episode: {TREATMENT_AVG_PER_EPISODE_MIN}-{TREATMENT_AVG_PER_EPISODE_MAX}",
        )),
    )

    for position, (heading, rows) in enumerate(report):
        if position:
            print()  # blank line between sections, none after the last
        print(heading)
        for row in rows:
            print(f"  {row}")
