#!/usr/bin/python3
"""
Validate Episode Arc

Pre-generation validation for episode_arc.md to ensure:
- Thread coherence (plant before payoff, at least 1 advance each)
- Minimum 6 threads
- Hook ratio ~80% silent / 20% dialogue
- Cliffhanger ratio ~80% mid-action / 20% aftermath
- No run of 4+ consecutive hooks or cliffhangers of the same type (max 3 consecutive allowed per CONSTANTS.md)
- Structural beats present at eps 15, 30, 45, 60
- All 60 episodes present with required fields

Usage: python3 validate_episode_arc.py <episode_arc_path>
Example: python3 validate_episode_arc.py ./leviathan/bible/episode_arc.md

Exit codes:
- 0 = Passed (all checks OK)
- 1 = Errors found (must fix before generation)
- 2 = Warnings only (can proceed but review recommended)
"""

import sys
import re
from pathlib import Path
from collections import defaultdict


def _split_table_row(line):
    """Split one markdown table line into a list of stripped cell strings.

    Returns None when the line is not a table row, i.e. it does not start
    with a pipe once surrounding whitespace is removed. Backslash-escaped
    pipes stay inside their cell.
    """
    stripped = line.strip()
    if not stripped.startswith('|'):
        return None
    inner = stripped[1:]
    # A trailing pipe only closes the last cell; drop it so the split does
    # not yield a spurious empty cell (unless that pipe is itself escaped).
    if inner.endswith('|') and not inner.endswith('\\|'):
        inner = inner[:-1]
    return [cell.strip() for cell in re.split(r'(?<!\\)\|', inner)]


def _is_separator_row(cells):
    """Return True for a header separator row such as ``|---|:---:|``."""
    return all(re.fullmatch(r':?-+:?', cell) for cell in cells)


def parse_episode_arc(content):
    """Parse episode_arc.md text and extract the data the validators need.

    Tables are read one line at a time, so a row can never absorb cells
    from the row that follows it.

    Args:
        content: Full text of the episode arc markdown file.

    Returns:
        A dict with these keys:
        - "threads": list of dicts (name, type, plant_ep, advances,
          payoff_ep, description) read from the Thread Index table.
        - "episodes": dict mapping episode number to a dict of its row
          cells (number, title, one_line, must_contain,
          hook_or_cliffhanger, cliffhanger_or_thematic, thematic).
        - "has_thread_index": True when a Thread Index heading was found.
        - "structural_beats": dict mapping 15/30/45/60 to whether a
          matching beat keyword shares a line with that episode number.
    """
    result = {
        "threads": [],
        "episodes": {},
        "has_thread_index": False,
        "structural_beats": {15: False, 30: False, 45: False, 60: False},
    }

    # Thread Index: the heading match itself decides has_thread_index, so
    # detection and extraction agree on case and spacing.
    thread_section = re.search(
        r'##\s*THREAD\s*INDEX.*?(?=##|\Z)',
        content,
        re.IGNORECASE | re.DOTALL,
    )
    if thread_section:
        result["has_thread_index"] = True

        table_rows = []
        for line in thread_section.group(0).splitlines():
            cells = _split_table_row(line)
            if cells is not None and not _is_separator_row(cells):
                table_rows.append(cells)

        for cells in table_rows[1:]:  # Skip header row
            if len(cells) < 6:
                # Malformed row: the table is expected to have 6 columns.
                continue
            name = cells[0]
            # Skip blank names, dash-led rows and '{placeholder}' templates.
            if not name or name.startswith(('-', '{')):
                continue
            result["threads"].append({
                "name": name,
                "type": cells[1],
                "plant_ep": parse_episode_number(cells[2]),
                "advances": parse_episode_list(cells[3]),
                "payoff_ep": parse_episode_number(cells[4]),
                "description": cells[5],
            })

    # Episode tables live under "### SEQUENCE n" headers.
    # Format: | Ep | Title | One-Line | Must Contain | Hook | Cliffhanger | Thematic Beat |
    # or old format: | Ep | Title | One-Line | Must Contain | Cliffhanger Type | Thematic Beat |
    sequence_pattern = r'###\s*SEQUENCE\s*\d+.*?(?=###|\Z)'
    for section in re.findall(sequence_pattern, content, re.IGNORECASE | re.DOTALL):
        for line in section.splitlines():
            cells = _split_table_row(line)
            if cells is None or len(cells) < 6:
                continue
            if not re.fullmatch(r'\d+', cells[0]):
                # Header, separator or any row not keyed by an episode number.
                continue
            # Old-format rows have 6 cells; pad so "thematic" is empty.
            cells = cells + [""] * (7 - len(cells))
            ep_num = int(cells[0])
            result["episodes"][ep_num] = {
                "number": ep_num,
                "title": cells[1],
                "one_line": cells[2],
                "must_contain": cells[3],
                "hook_or_cliffhanger": cells[4],  # Hook (new) or cliffhanger (old)
                "cliffhanger_or_thematic": cells[5],
                "thematic": cells[6],
            }

    # Structural beats: a keyword must share a line with the episode number.
    structural_keywords = {
        15: ["POINT OF NO RETURN", "Plot Point 1", "Threshold"],
        30: ["MIDPOINT", "Midpoint"],
        45: ["ALL IS LOST", "Low Point", "Dark Night"],
        60: ["RESOLUTION", "Resolution", "Climax"],
    }
    for ep_num, keywords in structural_keywords.items():
        # Digit guards keep e.g. "150" or "2015" from counting as episode 15.
        number = rf'(?<!\d){ep_num}(?!\d)'
        for keyword in keywords:
            escaped = re.escape(keyword)
            pattern = rf'{number}.*?{escaped}|{escaped}.*?{number}'
            if re.search(pattern, content, re.IGNORECASE):
                result["structural_beats"][ep_num] = True
                break

    return result


def parse_episode_number(text):
    """Return the first run of digits in *text* as an int.

    Handles cells such as 'Ep 5' or '5'; returns None when *text*
    contains no digits.
    """
    found = re.search(r'(\d+)', text)
    if found is None:
        return None
    return int(found.group(1))


def parse_episode_list(text):
    """Return every run of digits in *text* as a list of ints, in order.

    Handles cells such as 'Ep 10, 25' or '10, 25'; returns an empty list
    when *text* contains no digits.
    """
    episode_numbers = []
    for found in re.finditer(r'(\d+)', text):
        episode_numbers.append(int(found.group(1)))
    return episode_numbers


def check_hook_types(episodes):
    """Check the hook type distribution and runs of identical hook types.

    Each episode's "hook_or_cliffhanger" cell is classified as SILENT or
    DIALOGUE (case-insensitive substring match); episodes whose cell names
    neither are left out of both the ratio and the run check.

    Args:
        episodes: dict mapping episode number to its parsed row dict.

    Returns:
        (issues, warnings): two lists of message strings. An issue is
        reported once per run of 4 or more identical hook types, naming
        the episode at which the run actually ends. A warning is reported
        when the silent share falls outside 70%-90%, or when no hook
        types were found at all.
    """
    issues = []
    warnings = []
    max_run = 3  # Longest run of one hook type allowed per CONSTANTS.md

    # Collect (type, episode) pairs in episode order.
    hook_types = []
    for ep_num in sorted(episodes):
        hook_field = episodes[ep_num].get("hook_or_cliffhanger", "").upper()
        if "SILENT" in hook_field:
            hook_types.append(("SILENT", ep_num))
        elif "DIALOGUE" in hook_field:
            hook_types.append(("DIALOGUE", ep_num))
        # If no explicit hook type, skip the episode

    if not hook_types:
        warnings.append("No explicit Hook column found - consider adding to episode_arc.md")
        return issues, warnings

    # Ratio check: target is ~80% silent.
    silent_count = sum(1 for kind, _ in hook_types if kind == "SILENT")
    silent_ratio = silent_count / len(hook_types)
    if silent_ratio < 0.7:
        warnings.append(f"Hook ratio: {silent_ratio:.0%} silent - target is ~80% silent")
    elif silent_ratio > 0.9:
        warnings.append(f"Hook ratio: {silent_ratio:.0%} silent - consider adding more dialogue hooks for variety")

    # Run check: flush a run when the type changes. The (None, None)
    # sentinel forces the final run to be flushed as well, so every
    # over-long run is reported exactly once with its true last episode.
    run_type = None
    run_len = 0
    run_end = None
    for hook_type, ep_num in hook_types + [(None, None)]:
        if hook_type == run_type:
            run_len += 1
        else:
            if run_len > max_run:
                issues.append(f"4+ consecutive {run_type} hooks ending at Episode {run_end}")
            run_type = hook_type
            run_len = 1
        run_end = ep_num

    return issues, warnings


def check_cliffhanger_types(episodes):
    """Check the cliffhanger type distribution and runs of identical types.

    Each episode is classified as MID-ACTION or AFTERMATH by a
    case-insensitive substring match. The "cliffhanger_or_thematic" cell
    is tried first and "hook_or_cliffhanger" second, because the
    cliffhanger column sits in a different position in the old and new
    table formats. Episodes matching neither are left out.

    Args:
        episodes: dict mapping episode number to its parsed row dict.

    Returns:
        (issues, warnings): two lists of message strings. An issue is
        reported once per run of 4 or more identical cliffhanger types,
        naming the episode at which the run actually ends. A warning is
        reported when the mid-action share falls outside 70%-90%, or when
        no cliffhanger types were found at all.
    """
    issues = []
    warnings = []
    max_run = 3  # Longest run of one cliffhanger type allowed per CONSTANTS.md

    # Collect (type, episode) pairs in episode order.
    cliffhanger_types = []
    for ep_num in sorted(episodes):
        ep = episodes[ep_num]
        for field in ("cliffhanger_or_thematic", "hook_or_cliffhanger"):
            cliff_field = ep.get(field, "").upper()
            if "MID-ACTION" in cliff_field or "MIDACTION" in cliff_field:
                cliffhanger_types.append(("MID-ACTION", ep_num))
                break
            if "AFTERMATH" in cliff_field:
                cliffhanger_types.append(("AFTERMATH", ep_num))
                break

    if not cliffhanger_types:
        warnings.append("No explicit cliffhanger types found in episode tables")
        return issues, warnings

    # Ratio check: target is ~80% mid-action.
    midaction_count = sum(1 for kind, _ in cliffhanger_types if kind == "MID-ACTION")
    midaction_ratio = midaction_count / len(cliffhanger_types)
    if midaction_ratio < 0.7:
        warnings.append(f"Cliffhanger ratio: {midaction_ratio:.0%} mid-action - target is ~80%")
    elif midaction_ratio > 0.9:
        warnings.append(f"Cliffhanger ratio: {midaction_ratio:.0%} mid-action - consider adding more aftermath for variety")

    # Run check: flush a run when the type changes. The (None, None)
    # sentinel forces the final run to be flushed as well, so every
    # over-long run is reported exactly once with its true last episode.
    run_type = None
    run_len = 0
    run_end = None
    for cliff_type, ep_num in cliffhanger_types + [(None, None)]:
        if cliff_type == run_type:
            run_len += 1
        else:
            if run_len > max_run:
                issues.append(f"4+ consecutive {run_type} cliffhangers ending at Episode {run_end}")
            run_type = cliff_type
            run_len = 1
        run_end = ep_num

    return issues, warnings


def validate_threads(threads):
    """Check the parsed Thread Index entries for coherence.

    Errors: fewer than 6 threads, a plant episode that is not before the
    payoff episode, or an advance episode that does not fall strictly
    between plant and payoff. Warnings: a thread with no advances, or a
    thread type outside the recognised list.

    Args:
        threads: list of thread dicts with the keys "name", "type",
            "plant_ep", "advances" and "payoff_ep".

    Returns:
        (issues, warnings): two lists of message strings.
    """
    known_types = ["object", "phrase", "image", "memory", "foreshadow"]
    errors = []
    notes = []

    thread_count = len(threads)
    if thread_count < 6:
        errors.append(f"Only {thread_count} threads defined - minimum 6 required")

    for entry in threads:
        label = entry["name"]
        planted = entry["plant_ep"]
        steps = entry["advances"]
        paid_off = entry["payoff_ep"]

        # The plant has to come strictly before the payoff.
        if planted and paid_off and not planted < paid_off:
            errors.append(f"Thread '{label}': Plant (Ep {planted}) must be before Payoff (Ep {paid_off})")

        if steps:
            # Every advance has to sit strictly between plant and payoff.
            for step in steps:
                if planted and not step > planted:
                    errors.append(f"Thread '{label}': Advance (Ep {step}) is before or at Plant (Ep {planted})")
                if paid_off and not step < paid_off:
                    errors.append(f"Thread '{label}': Advance (Ep {step}) is at or after Payoff (Ep {paid_off})")
        else:
            notes.append(f"Thread '{label}': No advances defined - consider adding at least one")

        # An unknown thread type only earns a warning.
        kind = entry["type"]
        if kind and kind.lower() not in known_types:
            notes.append(f"Thread '{label}': Unrecognized type '{kind}' - expected one of {known_types}")

    return errors, notes


def validate_structural_beats(structural_beats):
    """Report the key episodes whose structural beat was not detected.

    Args:
        structural_beats: dict mapping episode number to a bool saying
            whether a beat was found for it.

    Returns:
        (issues, warnings): issues is always empty; warnings holds one
        message listing the episodes without a detected beat, if any.
    """
    absent = []
    for episode, found in structural_beats.items():
        if not found:
            absent.append(episode)

    notes = []
    if absent:
        notes.append(f"Structural beats not clearly marked for episodes: {absent}")
    return [], notes


def validate_episode_coverage(episodes):
    """Check that episodes 1-60 all exist and carry a title and one-line.

    Missing episodes are errors. An empty title or one-line, or one that
    still starts with '{' (an unfilled template placeholder), is a warning.

    Args:
        episodes: dict mapping episode number to a row dict that has the
            keys "title" and "one_line".

    Returns:
        (issues, warnings): two lists of message strings.
    """
    def is_unfilled(value):
        # Blank, or still an unfilled '{placeholder}' from the template.
        return not value or value.startswith('{')

    errors = []
    absent = [number for number in range(1, 61) if number not in episodes]
    if len(absent) > 10:
        # Long gaps are summarised rather than listed in full.
        errors.append(f"Missing {len(absent)} episodes - first few: {absent[:5]}...")
    elif absent:
        errors.append(f"Missing episodes: {absent}")

    untitled = [number for number, row in episodes.items() if is_unfilled(row["title"])]
    unsummarised = [number for number, row in episodes.items() if is_unfilled(row["one_line"])]

    notes = []
    if untitled:
        notes.append(f"Episodes with placeholder/empty titles: {untitled[:10]}{'...' if len(untitled) > 10 else ''}")
    if unsummarised:
        notes.append(f"Episodes with placeholder/empty one-lines: {unsummarised[:10]}{'...' if len(unsummarised) > 10 else ''}")

    return errors, notes


def main():
    """Command-line entry point: validate the arc file named in sys.argv[1].

    Prints a report to stdout and always terminates through sys.exit:
    - 0 when every check passes,
    - 2 when there are only warnings,
    - 1 when errors were found, or when validation could not run at all
      (missing argument, path is not a file, unreadable file).

    Failures to run use exit code 1 rather than 2 because 2 is the
    "warnings only, safe to proceed" code and must never be returned for
    a file that was not actually validated.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 validate_episode_arc.py <episode_arc_path>", file=sys.stderr)
        print("Example: python3 validate_episode_arc.py ./leviathan/bible/episode_arc.md", file=sys.stderr)
        sys.exit(1)

    arc_path = Path(sys.argv[1]).resolve()

    # is_file() also rejects directories, which exists() would let through.
    if not arc_path.is_file():
        print(f"Error: File not found: {arc_path}", file=sys.stderr)
        sys.exit(1)

    print(f"\n{'='*60}")
    print("EPISODE ARC VALIDATION")
    print(f"{'='*60}")
    print(f"File: {arc_path}")

    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default encoding.
    try:
        content = arc_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading file: {e}", file=sys.stderr)
        sys.exit(1)

    parsed = parse_episode_arc(content)

    all_issues = []
    all_warnings = []

    # Check for Thread Index
    print("\n--- Thread Index ---")
    if parsed["has_thread_index"]:
        print(f"Thread Index found: {len(parsed['threads'])} threads")
        issues, warnings = validate_threads(parsed["threads"])
        all_issues.extend(issues)
        all_warnings.extend(warnings)
    else:
        all_warnings.append("No Thread Index section found - consider adding one")
        print("No Thread Index section found")

    # Check episodes
    print("\n--- Episode Coverage ---")
    print(f"Episodes found: {len(parsed['episodes'])}")
    issues, warnings = validate_episode_coverage(parsed["episodes"])
    all_issues.extend(issues)
    all_warnings.extend(warnings)

    # Check hook types
    print("\n--- Hook Types ---")
    issues, warnings = check_hook_types(parsed["episodes"])
    all_issues.extend(issues)
    all_warnings.extend(warnings)

    # Check cliffhanger types
    print("\n--- Cliffhanger Types ---")
    issues, warnings = check_cliffhanger_types(parsed["episodes"])
    all_issues.extend(issues)
    all_warnings.extend(warnings)

    # Check structural beats
    print("\n--- Structural Beats ---")
    issues, warnings = validate_structural_beats(parsed["structural_beats"])
    all_issues.extend(issues)
    all_warnings.extend(warnings)
    for ep, present in parsed["structural_beats"].items():
        status = "Found" if present else "Not found"
        print(f"  Episode {ep}: {status}")

    # Report results
    print(f"\n{'='*60}")

    if all_issues:
        print(f"\n ERRORS ({len(all_issues)}) - Must fix before generation:")
        for issue in all_issues:
            print(f"  - {issue}")

    if all_warnings:
        print(f"\n WARNINGS ({len(all_warnings)}) - Review recommended:")
        for warning in all_warnings:
            print(f"  - {warning}")

    # Errors take precedence over warnings when choosing the exit code.
    if all_issues:
        print(f"\n{'='*60}")
        print(f"RESULT: FAILED - Fix {len(all_issues)} error(s) before generation")
        print(f"{'='*60}\n")
        sys.exit(1)
    elif all_warnings:
        print(f"\n{'='*60}")
        print(f"RESULT: PASSED WITH WARNINGS - Review {len(all_warnings)} warning(s)")
        print(f"{'='*60}\n")
        sys.exit(2)
    else:
        print(f"\n{'='*60}")
        print("RESULT: PASSED - Episode arc is ready for generation")
        print(f"{'='*60}\n")
        sys.exit(0)


# Script entry point: run the validator only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
