#!/usr/bin/python3
"""
Batch Validation Script (Format-Aware)

Validates that all episodes in a batch meet format-specific constraints.
Detects format from project_config.json and dispatches to the correct
format validator. Falls back to Kill Box (V12) if no config is found.

Usage: python3 validate_batch.py <project_path> <batch_number>
Example: python3 validate_batch.py ./singularity 1

Returns:
- Exit code 0: All episodes pass
- Exit code 1: One or more episodes fail (details printed)
- Exit code 2: Missing episodes or other error
"""

import json
import sys
import re
import importlib.util
from pathlib import Path

# Put the engine's shared tools/ directory on the import path when it exists
# (presumably where engine_constants lives — it is imported further down).
_SCRIPT_DIR = Path(__file__).parent.resolve()
_ENGINE_ROOT = _SCRIPT_DIR.parent.parent
_ENGINE_TOOLS = _ENGINE_ROOT.joinpath('tools')
if _ENGINE_TOOLS.exists():
    sys.path.insert(0, f"{_ENGINE_TOOLS}")


# =============================================================================
# FORMAT DETECTION & DISPATCH
# =============================================================================

def _detect_format(project_path: Path) -> str:
    """
    Return the format name declared in ``<project_path>/project_config.json``.

    Falls back to 'kill_box' when the config file is absent or unreadable,
    is not valid UTF-8 JSON, is not a JSON object, or when its "format" entry
    is missing or is not a non-empty string.
    """
    default_format = 'kill_box'
    config_file = project_path / 'project_config.json'
    if not config_file.exists():
        return default_format
    try:
        config = json.loads(config_file.read_text(encoding='utf-8'))
    except (ValueError, OSError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return default_format

    # A top-level list/number/etc. has no .get(); treat it as "no format".
    if not isinstance(config, dict):
        return default_format

    format_name = config.get('format', default_format)
    # null / non-string / blank values would otherwise crash later when the
    # name is joined into the formats/<name>/validate.py path.
    if not isinstance(format_name, str) or not format_name.strip():
        return default_format
    return format_name


def _load_format_validator(format_name: str):
    """
    Load ``<engine root>/formats/<format_name>/validate.py`` as a module.

    The module is returned only when it executes cleanly and defines both
    ``validate_episode`` and ``validate_batch``; otherwise None is returned.
    Exceptions raised while executing the file are reported as a WARNING on
    stdout instead of being propagated.
    """
    candidate = _ENGINE_ROOT / 'formats' / format_name / 'validate.py'
    if not candidate.exists():
        return None

    module_name = f"formats.{format_name}.validate"
    loader_spec = importlib.util.spec_from_file_location(module_name, candidate)
    validator = importlib.util.module_from_spec(loader_spec)
    required = ('validate_episode', 'validate_batch')
    try:
        loader_spec.loader.exec_module(validator)
        if all(hasattr(validator, attr) for attr in required):
            return validator
    except Exception as e:
        print(f"WARNING: Failed to load {format_name} validator: {e}")
    return None


def _try_format_dispatch(project_path: Path, batch_num: int, ep_start: int,
                         ep_end: int, batch_size: int):
    """
    Validate the batch with a format-specific validator, if the project has one.

    Returns False when the caller should fall through to the legacy Kill Box
    validation: the detected format is 'kill_box', or no usable validator
    module exists for the detected format. In every other case this function
    prints a report and terminates the process with sys.exit() instead of
    returning:
      0 - the validator reported the whole batch as valid
      1 - the validator reported one or more failing episodes
      2 - the episodes directory is missing, or the validator raised or did
          not return a dict

    The validator's validate_batch() receives the expected episode paths
    (episodes/ep_NNN.md for ep_start..ep_end, whether or not they exist) and
    is read as a dict with 'valid' and 'episode_results' keys.
    ``batch_size`` is accepted for call compatibility and is not used.
    """
    format_name = _detect_format(project_path)

    # For kill_box, fall through to the legacy path (preserves all existing
    # behavior including LLM value-turn checks, meta-reference checks, etc.)
    if format_name == 'kill_box':
        return False

    validator = _load_format_validator(format_name)
    if validator is None:
        print(f"WARNING: No validator for format '{format_name}', falling back to kill_box")
        return False

    # --- Format-specific validation path ---
    episodes_dir = project_path / "episodes"
    if not episodes_dir.exists():
        print(f"Error: Episodes directory does not exist: {episodes_dir}")
        sys.exit(2)

    print(f"\n{'=' * 60}")
    print(f"VALIDATING BATCH {batch_num} (Episodes {ep_start}-{ep_end})")
    print(f"Project: {project_path.name}  |  Format: {format_name}")
    print(f"{'=' * 60}\n")

    episode_paths = [str(episodes_dir / f"ep_{ep_num:03d}.md")
                     for ep_num in range(ep_start, ep_end + 1)]

    # A crashing or misbehaving validator is an "other error" (exit 2), not an
    # episode failure; an uncaught exception would otherwise exit with 1.
    try:
        result = validator.validate_batch(episode_paths)
    except Exception as exc:
        print(f"Error: {format_name} validator raised {type(exc).__name__}: {exc}")
        sys.exit(2)
    if not isinstance(result, dict):
        print(f"Error: {format_name} validator returned {type(result).__name__}, expected dict")
        sys.exit(2)

    all_passed = result.get('valid', False)
    # `or []` also tolerates an explicit None from the validator.
    episode_results = result.get('episode_results') or []

    for ep_result in episode_results:
        ep_file = ep_result.get('file', 'unknown')
        ep_name = Path(ep_file).name
        if ep_result.get('valid', False):
            metrics = ep_result.get('metrics', {})
            wc = metrics.get('total_word_count', metrics.get('word_count', '?'))
            print(f"  [PASS] {ep_name}: {wc} words")
        else:
            print(f"  [FAIL] {ep_name}:")
            for err in ep_result.get('errors', []):
                print(f"         - {err}")
            for warn in ep_result.get('warnings', []):
                print(f"         ~ {warn}")

    print(f"\n{'=' * 60}")
    if all_passed:
        print(f"BATCH {batch_num} VALIDATED: All episodes pass {format_name} criteria")
        print(f"{'=' * 60}\n")
        sys.exit(0)

    failed = [r for r in episode_results if not r.get('valid')]
    print(f"BATCH {batch_num} FAILED: {len(failed)} episode(s) need revision")
    print(f"\nAfter fixing, re-run: python3 validate_batch.py {project_path.name} {batch_num}")
    print(f"{'=' * 60}\n")
    sys.exit(1)

# Import constants and shared counting functions from shared module (reads from CONSTANTS.md)
try:
    from engine_constants import (
        WORD_COUNT_MIN, WORD_COUNT_MAX,
        DIALOGUE_MAX_PERCENT, MAX_EXCHANGES,
        GENERATION_BATCH_SIZE,
        ANTHROPIC_HAIKU,
        get_anthropic_client,
        call_anthropic,
        parse_llm_field,
        extract_script_content,
        count_words as _shared_count_words,
        parse_dialogue_blocks,
        count_dialogue_words as _shared_count_dialogue_words,
        count_exchanges as _shared_count_exchanges,
        calculate_dialogue_percent,
    )
    EXCHANGE_MAX = MAX_EXCHANGES  # Alias for compatibility
    _USE_SHARED = True
except ImportError as _engine_import_error:
    # Fallback when engine_constants is unavailable -- or importable but
    # missing one of the names above. Include the underlying error so an
    # outdated module is not mistaken for a missing one.
    print(f"WARNING: Could not import engine_constants ({_engine_import_error}), using fallback values")
    WORD_COUNT_MIN = 450
    WORD_COUNT_MAX = 500
    DIALOGUE_MAX_PERCENT = 40
    EXCHANGE_MAX = 8
    GENERATION_BATCH_SIZE = 5
    # Stubs: get_anthropic_client() returning None makes check_value_turn()
    # return before it would need ANTHROPIC_HAIKU, so that name is not
    # defined on this path.
    get_anthropic_client = lambda: None
    call_anthropic = lambda client, model, prompt, max_tokens=200: None
    parse_llm_field = lambda result, field, expected=None: None
    extract_script_content = lambda content: content
    _USE_SHARED = False

def extract_script_content(content):
    """
    Return the text that word counting should operate on.

    Every word in the file counts: nothing (headers, metadata, notes) is
    excluded, so the content is handed back unchanged. This definition
    rebinds the ``extract_script_content`` name imported from
    engine_constants above, so the shared version is not used by this script.
    """
    counted_text = content  # intentionally unfiltered
    return counted_text


def count_words(script_text):
    """
    Return the number of words in ``script_text``.

    Empty or None input counts as 0 words. When engine_constants was imported
    the canonical counter is used; otherwise a whitespace split is used and a
    warning about possible count differences is printed on every call.
    """
    if not script_text:
        return 0
    if not _USE_SHARED:
        print("  WARNING: Using fallback word counter (text.split) — counts may differ from canonical engine_constants.count_words()")
        return len(script_text.split())
    return _shared_count_words(script_text)


def count_dialogue_words(script_text):
    """
    Count words spoken in dialogue (lines following an ALL-CAPS character cue).

    Delegates to engine_constants (parse_dialogue_blocks +
    count_dialogue_words) when the shared module was imported. Otherwise a
    local heuristic is used:
      - a cue is a non-empty ALL-CAPS line shorter than 30 characters that is
        not a scene heading (INT./EXT.), does not start with '.', contains no
        ':', and is not a camera/transition keyword;
      - lines after a cue are dialogue until a blank line or a scene heading;
        parentheticals inside dialogue are not counted.
    In the fallback, empty or None input counts as 0.
    """
    if _USE_SHARED:
        blocks = parse_dialogue_blocks(script_text)
        return _shared_count_dialogue_words(blocks)

    # Fallback heuristic. Scene headings like "INT. KITCHEN - NIGHT" are
    # all-caps as well; exclude them from cue detection (as
    # extract_prose_section and check_hook_silent do) so the action lines
    # after them are not miscounted as dialogue.
    skip_words = frozenset({'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
                            'INSERT', 'CONTINUOUS', 'LATER', 'PULL BACK', 'CLICK'})
    dialogue_words = 0
    in_dialogue = False

    for line in (script_text or '').split('\n'):
        stripped = line.strip()
        is_heading = stripped.startswith(('INT.', 'EXT.'))

        is_cue = (bool(stripped) and stripped.isupper() and len(stripped) < 30
                  and not is_heading
                  and not stripped.startswith('.') and ':' not in stripped
                  and stripped not in skip_words
                  and not stripped.startswith('PULL'))
        if is_cue:
            in_dialogue = True
            continue

        if in_dialogue:
            if stripped.startswith('(') and stripped.endswith(')'):
                continue  # parenthetical: stays in dialogue, not counted
            if not stripped or is_heading:
                in_dialogue = False
                continue
            dialogue_words += len(stripped.split())

    return dialogue_words


def count_exchanges(script_text):
    """Count dialogue exchanges (character cues that introduce dialogue).

    Delegates to engine_constants (parse_dialogue_blocks + count_exchanges)
    when the shared module was imported. The local fallback counts every
    non-empty ALL-CAPS line shorter than 30 characters that is not a scene
    heading (INT./EXT.), does not start with '.', contains no ':', and is not
    a camera/transition keyword. In the fallback, empty or None input counts
    as 0.
    """
    if _USE_SHARED:
        blocks = parse_dialogue_blocks(script_text)
        return _shared_count_exchanges(blocks)

    # Fallback heuristic. Scene headings such as "INT. KITCHEN - NIGHT" are
    # all-caps too; exclude them explicitly (as extract_prose_section and
    # check_hook_silent do) so each heading is not counted as an exchange.
    skip_words = frozenset({'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
                            'INSERT', 'CONTINUOUS', 'LATER', 'PULL BACK', 'CLICK'})

    def is_cue(line):
        # One-line purpose: decide whether a stripped line is a character cue.
        return (bool(line) and line.isupper() and len(line) < 30
                and not line.startswith(('INT.', 'EXT.', '.'))
                and ':' not in line
                and line not in skip_words
                and not line.startswith('PULL'))

    return sum(1 for raw in (script_text or '').split('\n') if is_cue(raw.strip()))


def extract_prose_section(content, section_name):
    """
    Extract the prose (action/description) lines of one Kill Box section.

    The section runs from its ``# [MM:SS - MM:SS] <section_name>`` heading
    line up to the next ``# [`` timing heading, a ``---`` separator, or the
    end of the content (heading match is case-insensitive; section_name is
    matched literally).

    Excludes:
    - The section header itself
    - Scene headings (INT./EXT.)
    - Character names (ALL CAPS cues)
    - Dialogue following character names, up to the next blank line
    - Parentheticals

    Returns the kept lines (stripped) joined with newlines, or "" when the
    section is not found.
    """
    # re.escape keeps section names containing regex metacharacters literal.
    pattern = (
        rf'#\s*\[[\d:]+\s*-\s*[\d:]+\]\s*{re.escape(section_name)}'
        r'.*?\n(.*?)(?=#\s*\[[\d:]+|---|\Z)'
    )
    match = re.search(pattern, content, re.DOTALL | re.IGNORECASE)
    if not match:
        return ""

    skip_words = frozenset({'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
                            'INSERT', 'CONTINUOUS', 'LATER', 'PULL BACK', 'CLICK'})
    prose_lines = []
    in_dialogue = False  # True between a character cue and the next blank line

    for line in match.group(1).split('\n'):
        stripped = line.strip()

        # A blank line is the only thing that closes a dialogue block.
        if not stripped:
            in_dialogue = False
            continue

        # Scene headings are never prose (and do not end dialogue).
        if stripped.startswith(('INT.', 'EXT.')):
            continue

        # Character cue: ALL CAPS, short, not a camera/transition keyword.
        if (stripped.isupper() and len(stripped) < 30
                and not stripped.startswith('.') and ':' not in stripped
                and stripped not in skip_words
                and not stripped.startswith('PULL')):
            in_dialogue = True
            continue

        # Every non-blank line inside a dialogue block (spoken lines and
        # parentheticals alike) is dialogue, not prose.
        if in_dialogue:
            continue

        # Parentheticals outside dialogue are direction, not prose.
        if stripped.startswith('(') and stripped.endswith(')'):
            continue

        prose_lines.append(stripped)

    return '\n'.join(prose_lines)


# ---------------------------------------------------------------------------
# LLM-based Value Turn Verification (Scaffolding Gate G1)
# Uses Haiku for per-episode binary check: does the core value shift polarity?
# ---------------------------------------------------------------------------

# Prompt template consumed by check_value_turn() via str.format(). The only
# placeholder is {episode_text}, so any literal brace added to this text must
# be doubled ({{ / }}) or format() will fail. The model is asked to answer in
# VALUE / SHIFTS / REASON lines, which check_value_turn() reads back with
# parse_llm_field().
# NOTE(review): "450-word" is hard-coded here and not tied to
# WORD_COUNT_MIN / WORD_COUNT_MAX — confirm it should not be derived from them.
_VALUE_TURN_PROMPT = """You are evaluating a 450-word microdrama episode for structural integrity.

RULE: Every episode must shift a value polarity. The emotional or dramatic charge at the end must differ from the beginning. An episode that describes a situation without turning it is not a story event.

Here is the episode:

<episode>
{episode_text}
</episode>

Answer these two questions:
1. What is the core value at stake in this episode? (one phrase: e.g., "safety vs danger", "trust vs betrayal", "hope vs despair")
2. Does this value shift polarity from the beginning to the end? (YES or NO)

Output EXACTLY in this format:
VALUE: [your answer]
SHIFTS: [YES or NO]
REASON: [one sentence explaining why]"""


def check_value_turn(content):
    """
    Ask the Haiku model whether the episode's core value flips polarity (McKee).

    Only an explicit "SHIFTS: NO" answer fails the check. A missing client,
    an API/parsing exception, or an answer without a usable SHIFTS field all
    pass, with the detail string explaining why (graceful degradation).

    Returns (passed, detail_string)
    """
    client = get_anthropic_client()
    if client is None:
        return True, "SKIPPED (no ANTHROPIC_API_KEY — install anthropic and set key for Value Turn checks)"

    prompt = _VALUE_TURN_PROMPT.format(episode_text=extract_script_content(content))

    try:
        reply = client.messages.create(
            model=ANTHROPIC_HAIKU,
            max_tokens=200,
            messages=[{"role": "user", "content": prompt}],
        )
        answer = reply.content[0].text.strip()
        verdict = parse_llm_field(answer, "SHIFTS", ["YES", "NO"])

        if verdict is None:
            return True, f"INCONCLUSIVE — {answer[:100]}"
        if verdict == "YES":
            return True, f"Value turn confirmed: {parse_llm_field(answer, 'VALUE') or ''}"
        return False, f"No value turn detected. {parse_llm_field(answer, 'REASON') or ''}"
    except Exception as e:
        return True, f"SKIPPED (API error: {e})"


def check_meta_references(content):
    """
    Check that prose doesn't contain episode number references.

    Meta-references like "Episode 10" or "kept since Episode 1" are production
    metadata. The audience doesn't track episode numbers - they experience
    events. References should be to WHAT HAPPENED, not episode numbers.

    Only the prose of the five Kill Box sections (as returned by
    extract_prose_section) is inspected. Flags "Episode N" per section and
    "N episode(s)" across all sections combined.

    Returns (passed, issues) where passed is True if no issues found.
    """
    sections = ('THE HOOK', 'THE SETUP', 'THE ESCALATION', 'THE TURN', 'THE CLIFFHANGER')

    # Extract each section once; both checks below reuse it (previously every
    # section was parsed twice).
    prose_by_section = {name: extract_prose_section(content, name) for name in sections}

    issues = []
    for name, prose in prose_by_section.items():
        if not prose:
            continue
        # "Episode N" patterns (case-insensitive)
        matches = re.findall(r'\bEpisode\s+\d+\b', prose, re.IGNORECASE)
        if matches:
            issues.append(f"Meta-reference in {name}: {matches}")

    # "N episodes" patterns that reference an episode count
    full_prose = '\n'.join(prose_by_section.values())
    count_matches = re.findall(r'\b\d+\s+episodes?\b', full_prose, re.IGNORECASE)
    if count_matches:
        issues.append(f"Episode count reference: {count_matches}")

    return not issues, issues


def check_hook_silent(content):
    """
    Report whether THE HOOK section contains a character cue (i.e. dialogue).

    V12 allows 20% of hooks to have dialogue (pattern interrupts), so this is
    informational for tracking, not a strict pass/fail gate.

    Returns:
        (True, "No HOOK section found") when the hook heading is absent,
        (False, "Dialogue hook: <CUE>") for the first cue found, or
        (True, "Hook is silent") otherwise.
    """
    # THE HOOK runs from its timing heading to the next "# [" heading or end.
    hook_match = re.search(r'#\s*\[[\d:]+\s*-\s*[\d:]+\]\s*THE HOOK.*?\n(.*?)(?=#\s*\[[\d:]+|$)',
                           content, re.DOTALL | re.IGNORECASE)
    if not hook_match:
        return True, "No HOOK section found"

    # Same non-speaker keywords as the skip list in count_dialogue_words,
    # count_exchanges and extract_prose_section; 'LATER' and 'CLICK' were
    # missing here, so e.g. a CLICK sound cue was reported as a dialogue hook.
    skip_words = frozenset({'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
                            'INSERT', 'CONTINUOUS', 'LATER', 'PULL BACK', 'CLICK'})

    for line in hook_match.group(1).split('\n'):
        stripped = line.strip()
        # Candidate cue: non-empty ALL CAPS line shorter than 30 characters.
        if not (stripped and stripped.isupper() and len(stripped) < 30):
            continue
        # Scene headings, '.'-prefixed lines and lines with ':' are not cues.
        if stripped.startswith(('INT.', 'EXT.', '.')) or ':' in stripped:
            continue
        if stripped in skip_words or stripped.startswith('PULL'):
            continue
        return False, f"Dialogue hook: {stripped}"

    return True, "Hook is silent"


def check_formatting(content):
    """
    Verify that the V12 episode skeleton is present.

    Looks for the ``[[EPISODE N`` header and the five Kill Box section
    headings with their fixed timestamps (HOOK 00:00-00:05 through
    CLIFFHANGER 00:70-00:90). Matching is case-insensitive, and each
    timestamp and its section name must be on the same line.

    Returns (passed, issues): passed is True exactly when issues is empty.
    """
    problems = []

    has_header = re.search(r'\[\[EPISODE\s+\d+', content, re.IGNORECASE) is not None
    if not has_header:
        problems.append("Missing episode header [[EPISODE X: TITLE]]")

    # (section name, heading regex) for each Kill Box timing block, in order.
    kill_box_blocks = (
        ('THE HOOK', r'#\s*\[00:00\s*-\s*00:05\].*THE HOOK'),
        ('THE SETUP', r'#\s*\[00:05\s*-\s*00:15\].*THE SETUP'),
        ('THE ESCALATION', r'#\s*\[00:15\s*-\s*00:40\].*THE ESCALATION'),
        ('THE TURN', r'#\s*\[00:40\s*-\s*00:70\].*THE TURN'),
        ('THE CLIFFHANGER', r'#\s*\[00:70\s*-\s*00:90\].*THE CLIFFHANGER'),
    )
    absent = [name for name, heading in kill_box_blocks
              if re.search(heading, content, re.IGNORECASE) is None]
    if absent:
        problems.append(f"Missing Kill Box sections: {', '.join(absent)}")

    return not problems, problems


def validate_episode(filepath, value_turn_result=None):
    """Validate a single Kill Box episode file. Returns (passed, issues, stats).

    Args:
        filepath: pathlib.Path of the episode markdown file.
        value_turn_result: Optional pre-computed (passed, detail) tuple from the
            parallel value-turn prefetch; when None, check_value_turn() is
            called here.

    Returns:
        (passed, issues, stats): passed is True only when issues is empty;
        issues is a list of human-readable problems; stats is a dict of
        metrics, or None when the file does not exist.
    """
    issues = []

    if not filepath.exists():
        return False, [f"File not found: {filepath}"], None

    # Read explicitly as UTF-8 (as _detect_format does for the config) so
    # scripts with em dashes / curly quotes don't depend on the platform's
    # default locale encoding.
    content = filepath.read_text(encoding='utf-8')

    # Check formatting first
    format_passed, format_issues = check_formatting(content)
    issues.extend(format_issues)

    # Check for meta-references (episode numbers in prose)
    meta_passed, meta_issues = check_meta_references(content)
    issues.extend(meta_issues)

    # Text used for counting (extract_script_content decides what is included)
    script = extract_script_content(content)

    total_words = count_words(script)

    # Dialogue share of the total word count
    dialogue_words = count_dialogue_words(script)
    dialogue_percent = (dialogue_words / total_words * 100) if total_words > 0 else 0

    exchanges = count_exchanges(script)

    # Hook dialogue is tracked in stats only; it does not fail the episode.
    hook_silent, hook_msg = check_hook_silent(content)

    # Validate against constraints
    if total_words < WORD_COUNT_MIN:
        issues.append(f"Word count too LOW: {total_words} (min: {WORD_COUNT_MIN})")
    elif total_words > WORD_COUNT_MAX:
        issues.append(f"Word count too HIGH: {total_words} (max: {WORD_COUNT_MAX})")

    if dialogue_percent > DIALOGUE_MAX_PERCENT:
        issues.append(f"Dialogue too HIGH: {dialogue_percent:.1f}% (max: {DIALOGUE_MAX_PERCENT}%)")

    if exchanges > EXCHANGE_MAX:
        issues.append(f"Exchanges too HIGH: {exchanges} (max: {EXCHANGE_MAX})")

    # Value Turn Verification (Scaffolding Gate G1)
    if value_turn_result is not None:
        value_turn_passed, value_turn_detail = value_turn_result
    else:
        value_turn_passed, value_turn_detail = check_value_turn(content)
    if not value_turn_passed:
        issues.append(f"VALUE TURN: {value_turn_detail}")

    passed = len(issues) == 0

    stats = {
        "words": total_words,
        "dialogue_pct": round(dialogue_percent, 1),
        "exchanges": exchanges,
        "hook_silent": hook_silent,
        "format_valid": format_passed,
        "meta_valid": meta_passed,
        "value_turn": value_turn_passed,
        "value_turn_detail": value_turn_detail,
    }

    return passed, issues, stats


def main():
    """
    CLI entry point: validate one batch of episodes and exit with a status code.

    Usage: python3 validate_batch.py <project_path> <batch_number>

    Batch N covers episodes (N-1)*GENERATION_BATCH_SIZE+1 through
    N*GENERATION_BATCH_SIZE. Projects whose config declares a non-kill_box
    format are handed to _try_format_dispatch(), which exits on its own;
    otherwise the legacy Kill Box checks run here.

    Exits 0 when every episode passes, 1 when any episode fails or is missing,
    and 2 for usage errors or a missing project/episodes directory.
    """
    usage = "Usage: python3 validate_batch.py <project_path> <batch_number>"
    if len(sys.argv) < 3:
        print(usage)
        print("Example: python3 validate_batch.py ./singularity 1")
        sys.exit(2)

    # A malformed batch number is a usage error (exit 2). Previously int()
    # raised ValueError, and the traceback's exit status 1 was
    # indistinguishable from "episodes failed"; batch <= 0 produced
    # nonsensical negative episode ranges.
    try:
        batch_num = int(sys.argv[2])
    except ValueError:
        batch_num = 0
    if batch_num < 1:
        print(f"Error: batch_number must be a positive integer, got: {sys.argv[2]!r}")
        print(usage)
        sys.exit(2)

    project_path = Path(sys.argv[1]).resolve()

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}")
        sys.exit(2)

    episodes_dir = project_path / "episodes"
    if not episodes_dir.exists():
        print(f"Error: Episodes directory does not exist: {episodes_dir}")
        sys.exit(2)

    # Calculate episode range (uses GENERATION_BATCH_SIZE from CONSTANTS.md)
    ep_start = (batch_num - 1) * GENERATION_BATCH_SIZE + 1
    ep_end = batch_num * GENERATION_BATCH_SIZE

    # --- Format-aware dispatch ---
    # If the project uses a non-kill_box format, dispatch to its validator
    # and exit. Otherwise fall through to legacy Kill Box validation.
    _try_format_dispatch(project_path, batch_num, ep_start, ep_end, GENERATION_BATCH_SIZE)

    print(f"\n{'='*60}")
    print(f"VALIDATING BATCH {batch_num} (Episodes {ep_start}-{ep_end})")
    print(f"Project: {project_path.name}")
    print(f"{'='*60}\n")

    all_passed = True
    results = []

    # Pre-fetch value turn results in parallel (all Haiku calls at once)
    from concurrent.futures import ThreadPoolExecutor
    ep_files = {}
    for ep_num in range(ep_start, ep_end + 1):
        ep_file = episodes_dir / f"ep_{ep_num:03d}.md"
        if ep_file.exists():
            ep_files[ep_num] = ep_file

    vt_results = {}
    if ep_files and get_anthropic_client() is not None:
        def _vt_check(ep_num):
            # Read as UTF-8, matching how validate_episode should read the file.
            content = ep_files[ep_num].read_text(encoding='utf-8')
            return ep_num, check_value_turn(content)

        with ThreadPoolExecutor(max_workers=len(ep_files)) as executor:
            for ep_num, result in executor.map(_vt_check, ep_files.keys()):
                vt_results[ep_num] = result

    for ep_num in range(ep_start, ep_end + 1):
        ep_file = episodes_dir / f"ep_{ep_num:03d}.md"

        # NOTE(review): the module docstring lists missing episodes under exit
        # code 2, but here they count as failures (exit 1). Confirm which
        # contract callers expect.
        if not ep_file.exists():
            print(f"  [MISSING] Episode {ep_num}: {ep_file.name}")
            all_passed = False
            results.append((ep_num, False, ["File not found"], None))
            continue

        passed, issues, stats = validate_episode(ep_file, value_turn_result=vt_results.get(ep_num))
        results.append((ep_num, passed, issues, stats))

        if passed:
            hook_status = "silent" if stats['hook_silent'] else "DIALOGUE"
            vt_status = stats.get('value_turn_detail', '')
            vt_short = vt_status[:40] if vt_status else ''
            print(f"  [PASS] Episode {ep_num}: {stats['words']} words, "
                  f"{stats['dialogue_pct']}% dialogue, {stats['exchanges']} exchanges, "
                  f"hook: {hook_status}")
            if vt_short and not vt_short.startswith("SKIPPED"):
                print(f"         Value turn: {vt_short}")
        else:
            print(f"  [FAIL] Episode {ep_num}:")
            for issue in issues:
                print(f"         - {issue}")
            all_passed = False

    print(f"\n{'='*60}")

    if all_passed:
        print(f"BATCH {batch_num} VALIDATED: All episodes pass V12 criteria")
        print(f"{'='*60}\n")
        sys.exit(0)

    failed = [r for r in results if not r[1]]
    print(f"BATCH {batch_num} FAILED: {len(failed)} episode(s) need revision")
    print("\nFIX THESE EPISODES:")
    for ep_num, _passed, ep_issues, _stats in failed:
        print(f"  - Episode {ep_num}:")
        for issue in ep_issues:
            print(f"      {issue}")
    print(f"\nAfter fixing, re-run: python3 validate_batch.py {project_path.name} {batch_num}")
    print(f"{'='*60}\n")
    sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
