#!/usr/bin/python3
"""
Generate Batch Summary

Extracts metadata from completed batch episodes to create a per-batch summary file (batch_NN_summary.json).
This summary is used by the orchestrator for cross-series verification.

Usage: python3 generate_batch_summary.py <project_path> <batch_number>
Example: python3 generate_batch_summary.py ./leviathan 3

Reads: ./[project]/episodes/ep_XXX.md (GENERATION_BATCH_SIZE episodes per batch; the fallback value is 5)
Output: ./[project]/state/batch_NN_summary.json (batch number zero-padded to two digits)
"""

import json
import re
import sys
from datetime import datetime
from pathlib import Path

# Add engine tools to path for imports
# _SCRIPT_DIR is the resolved directory that contains this script. Both it and
# its grandparent directory are prepended to sys.path so that the imports
# below resolve no matter which working directory the script is started from.
_SCRIPT_DIR = Path(__file__).parent.resolve()
sys.path.insert(0, str(_SCRIPT_DIR))
sys.path.insert(0, str(_SCRIPT_DIR.parent.parent))  # CLAUDE_PROJECTS, for recoil.*
# NOTE: this import is not guarded; if recoil is not importable the script
# stops here with ImportError.
from recoil.core.paths import ProjectPaths

# Prefer the shared engine constants and dialogue helpers. _USE_SHARED records
# whether that import succeeded, and the functions below branch on it.
try:
    from engine_constants import (
        WORD_COUNT_MIN, WORD_COUNT_MAX,
        DIALOGUE_MAX_PERCENT, MAX_EXCHANGES,
        GENERATION_BATCH_SIZE,
        count_words as _shared_count_words,
        parse_dialogue_blocks,
        count_dialogue_words as _shared_count_dialogue_words,
        count_exchanges as _shared_count_exchanges,
        is_character_cue as _shared_is_character_cue,
    )
    _USE_SHARED = True
except ImportError:
    # Fallback mode: only the numeric constants are defined here. The
    # _shared_* helpers and parse_dialogue_blocks stay undefined, so every use
    # of them must be guarded by a check of _USE_SHARED.
    print("WARNING: Could not import engine_constants, using fallback values")
    WORD_COUNT_MIN = 450
    WORD_COUNT_MAX = 500
    DIALOGUE_MAX_PERCENT = 40
    MAX_EXCHANGES = 8
    GENERATION_BATCH_SIZE = 5
    _USE_SHARED = False


def count_words(text: str) -> int:
    """Return the number of words in ``text``.

    Empty or missing text counts as zero. When the shared engine helpers were
    imported the count is delegated to the shared implementation; otherwise
    words are the whitespace-separated tokens of ``text``.
    """
    if text:
        if not _USE_SHARED:
            return len(text.split())
        return _shared_count_words(text)
    return 0


def count_dialogue_words(content: str) -> int:
    """Count the words in ``content`` that belong to spoken dialogue.

    Delegates to the shared canonical implementation when ``engine_constants``
    was imported. Otherwise a line-based fallback heuristic is used: a short
    all-caps line is treated as a character cue, and every following line up
    to the next blank line or scene heading is counted as dialogue.

    Fallback rules:
      * Blank lines and ``INT.``/``EXT.`` scene headings end the current
        speech. Scene headings are never cues, even when fully capitalised.
      * Lines wrapped in parentheses are skipped and never start a speech.
      * Shot/transition labels (``ECU``, ``SFX``, ``PULL BACK`` ...) do not
        start a speech; one that appears inside a speech is still counted
        as dialogue text.

    Args:
        content: Full episode text.

    Returns:
        Number of whitespace-separated words found on dialogue lines.
    """
    if _USE_SHARED:
        blocks = parse_dialogue_blocks(content)
        return _shared_count_dialogue_words(blocks)

    # Fallback heuristic. Built once instead of on every line.
    non_cue_labels = {'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
                      'INSERT', 'CONTINUOUS', 'LATER', 'PULL BACK', 'CLICK'}
    dialogue_words = 0
    in_dialogue = False

    for line in content.split('\n'):
        stripped = line.strip()

        if not stripped:
            # A blank line always closes the current speech.
            in_dialogue = False
            continue

        if stripped.startswith(('INT.', 'EXT.')):
            # Checked before the cue test: an all-caps scene heading would
            # otherwise look like a character cue and turn the action text
            # below it into "dialogue".
            in_dialogue = False
            continue

        if stripped.startswith('(') and stripped.endswith(')'):
            # Parentheticals are neither dialogue nor cues, e.g. "(V.O.)".
            continue

        is_cue = (
            stripped.isupper()
            and len(stripped) < 30
            and not stripped.startswith(('.', 'PULL'))
            and ':' not in stripped
            and stripped not in non_cue_labels
        )
        if is_cue:
            in_dialogue = True
            continue

        if in_dialogue:
            dialogue_words += len(stripped.split())

    return dialogue_words


def count_exchanges(content: str) -> int:
    """Count dialogue exchanges (character cues) in ``content``.

    Delegates to the shared canonical implementation when ``engine_constants``
    was imported. Otherwise every line that looks like a character cue is
    counted: a non-empty, all-caps line shorter than 30 characters that is
    not a scene heading, parenthetical, shot/transition label, dot-prefixed
    line or a line containing a colon.

    Args:
        content: Full episode text.

    Returns:
        Number of character-cue lines found.
    """
    if _USE_SHARED:
        blocks = parse_dialogue_blocks(content)
        return _shared_count_exchanges(blocks)

    # Fallback heuristic. Built once instead of on every line.
    non_cue_labels = {'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
                      'INSERT', 'CONTINUOUS', 'LATER', 'PULL BACK', 'CLICK'}
    exchanges = 0

    for line in content.split('\n'):
        stripped = line.strip()
        if not stripped or not stripped.isupper() or len(stripped) >= 30:
            continue
        # Scene headings ("INT. LAB - NIGHT") are all caps too but are not
        # speakers; without this guard each scene inflated the exchange count.
        if stripped.startswith(('.', 'INT.', 'EXT.', 'PULL')) or ':' in stripped:
            continue
        # Upper-case parentheticals such as "(V.O.)" are direction, not cues.
        if stripped.startswith('(') and stripped.endswith(')'):
            continue
        if stripped in non_cue_labels:
            continue
        exchanges += 1

    return exchanges


def detect_hook_type(content: str) -> str:
    """Classify THE HOOK section of an episode as spoken or silent.

    The hook is the text under a ``# [mm:ss - mm:ss] THE HOOK`` header, up to
    the next timestamped ``# [...]`` header or the end of the text. The hook
    counts as ``'dialogue'`` as soon as one of its lines is a character cue.

    Cue detection uses the shared helper when ``engine_constants`` was
    imported. The fallback treats a short all-caps line as a cue unless it is
    a scene heading, a dot-prefixed line, contains a colon, or is one of the
    shot/transition labels used by the other fallback counters.

    Args:
        content: Full episode text.

    Returns:
        ``'dialogue'`` if the hook contains a character cue, else
        ``'silent'`` (also returned when no hook section is found).
    """
    hook_match = re.search(
        r'#\s*\[[\d:]+\s*-\s*[\d:]+\]\s*THE HOOK.*?\n(.*?)(?=#\s*\[[\d:]+|$)',
        content, re.DOTALL | re.IGNORECASE
    )

    if not hook_match:
        return "silent"

    # Same label set as the fallback word/exchange counters, including
    # 'LATER' and 'CLICK', so that all three fallbacks agree on what a cue is.
    non_cue_labels = {'ECU', 'CU', 'MCU', 'MS', 'WS', 'POV', 'SFX', 'VFX',
                      'INSERT', 'CONTINUOUS', 'LATER', 'PULL BACK', 'CLICK'}

    for line in hook_match.group(1).split('\n'):
        stripped = line.strip()

        if _USE_SHARED:
            if _shared_is_character_cue(stripped):
                return "dialogue"
            continue

        # Fallback detection
        if not stripped or not stripped.isupper() or len(stripped) >= 30:
            continue
        if stripped.startswith(('INT.', 'EXT.', '.', 'PULL')) or ':' in stripped:
            continue
        if stripped not in non_cue_labels:
            return "dialogue"

    return "silent"


def detect_cliffhanger_type(content: str) -> str:
    """Classify the episode's cliffhanger as 'mid-action' or 'aftermath'.

    Sources are tried in priority order and the first hit wins:
      1. an explicit ``Cliffhanger: <type>`` annotation,
      2. the type named in THE CLIFFHANGER section header,
      3. a bold ``**MID-ACTION**`` marker, then a bold ``**AFTERMATH**`` one.

    Falls back to 'mid-action' when nothing is found.
    """
    # Patterns whose first capture group names the type directly.
    labelled_patterns = (
        r'(?:Cliffhanger|CLIFFHANGER)[:\s]*\*?\*?(mid-action|aftermath|MID-ACTION|AFTERMATH)',
        r'#\s*\[[\d:]+\s*-\s*[\d:]+\]\s*THE CLIFFHANGER\s*[—–-]\s*(mid-action|aftermath)',
    )
    for pattern in labelled_patterns:
        found = re.search(pattern, content, re.IGNORECASE)
        if found is not None:
            return found.group(1).lower()

    # Stand-alone bold markers, checked in this fixed order.
    bold_markers = (
        (r'\*\*MID-ACTION\*\*', "mid-action"),
        (r'\*\*AFTERMATH\*\*', "aftermath"),
    )
    for marker, label in bold_markers:
        if re.search(marker, content, re.IGNORECASE):
            return label

    # Nothing detectable: assume mid-action.
    return "mid-action"


def extract_episode_title(content: str) -> str:
    """Return the title from the ``[[EPISODE n: Title]]`` header, or 'Untitled'."""
    header = re.search(r'\[\[EPISODE\s+\d+:\s*(.+?)\]\]', content, re.IGNORECASE)
    if header is None:
        return "Untitled"
    return header.group(1).strip()


def parse_treatment_threads(project_path: Path) -> list[dict]:
    """
    Parse the thread index table from treatment.md.

    Scanning starts at the first line containing "THREAD INDEX" and stops at
    the first line that begins with "---" (a horizontal rule). Within that
    region every pipe-delimited row with at least four cells becomes one
    thread entry, except:

      * separator rows, whose cells consist only of dashes/colons, and
      * the header row: the first row containing "Thread", plus any later
        "Thread" row that carries no episode numbers.

    A data row is therefore kept even if its name contains "Thread" or one of
    its cells uses "---" as an empty marker.

    Args:
        project_path: Project root that contains treatment.md.

    Returns:
        List of dicts with name, plant_eps, advance_eps, payoff_eps.
        Empty when treatment.md does not exist.
    """
    treatment = project_path / "treatment.md"
    if not treatment.is_file():
        return []

    threads = []
    in_table = False
    header_seen = False
    # Explicit encoding: the locale default is not UTF-8 on every platform.
    for line in treatment.read_text(encoding="utf-8").split("\n"):
        if "THREAD INDEX" in line:
            in_table = True
            continue
        if not in_table:
            continue
        if line.startswith("---"):
            # Horizontal rule: end of the thread index section.
            break
        if not line.startswith("|"):
            continue

        cols = [c.strip() for c in line.split("|")[1:-1]]

        # Markdown separator row such as |-----|:---:|
        if cols and all(re.fullmatch(r':?-+:?', c) for c in cols):
            continue

        if "Thread" in line:
            has_numbers = any(re.search(r'\d', c) for c in cols[1:4])
            # Only a row after the header that actually schedules episodes
            # is data; everything else mentioning "Thread" is header text.
            is_header = not (header_seen and has_numbers)
            header_seen = True
            if is_header:
                continue

        if len(cols) >= 4:
            threads.append({
                "name": cols[0],
                "plant_eps": _parse_episode_refs(cols[1]),
                "advance_eps": _parse_episode_refs(cols[2]),
                "payoff_eps": _parse_episode_refs(cols[3]),
            })
    return threads


def _parse_episode_refs(text: str) -> list[int]:
    """Parse episode references like 'Ep 1, 9' or 'Ep 39-40' into a list of ints.

    Every number in the text is collected in order of appearance. Two numbers
    joined by a hyphen, en dash or em dash are expanded as an inclusive range,
    whatever text precedes them ('Ep', 'Eps', 'ep.', 'Episodes' ...). A
    reversed range such as '40-39' contributes just its two endpoints.

    Returns an empty list for empty text, '--' or 'Throughout'.
    """
    if not text:
        return []
    cleaned = text.strip()
    if cleaned == "--" or cleaned.lower() == "throughout":
        return []

    nums = []
    for ref in re.finditer(r'(\d+)(?:\s*[-–—]\s*(\d+))?', cleaned):
        start = int(ref.group(1))
        if ref.group(2) is None:
            nums.append(start)
            continue
        end = int(ref.group(2))
        if end < start:
            # Not a usable range; keep both numbers rather than dropping them.
            nums.extend((start, end))
        else:
            nums.extend(range(start, end + 1))
    return nums


# Cache treatment threads per project path
_treatment_threads_cache: dict[str, list[dict]] = {}


def extract_threads(content: str, ep_num: int = 0, project_path: Path | None = None) -> dict:
    """
    Extract thread information for an episode.

    Two-tier detection:
    1. Treatment schedule: check if this episode number matches any thread's
       plant/advance/payoff schedule from treatment.md
    2. Explicit markers: look for [PLANT: X], [ADVANCE: X], [PAYOFF: X] in episode text

    The treatment schedule is the primary source. Markers are supplementary.
    """
    threads = {
        "planted": [],
        "advanced": [],
        "paid_off": []
    }

    # Tier 1: Treatment schedule lookup
    if project_path and ep_num > 0:
        cache_key = str(project_path)
        if cache_key not in _treatment_threads_cache:
            _treatment_threads_cache[cache_key] = parse_treatment_threads(project_path)
        for t in _treatment_threads_cache[cache_key]:
            tid = re.sub(r'[^A-Z0-9]', '_', t["name"].upper()).strip('_')
            if ep_num in t["plant_eps"]:
                threads["planted"].append(tid)
            if ep_num in t["advance_eps"]:
                threads["advanced"].append(tid)
            if ep_num in t["payoff_eps"]:
                threads["paid_off"].append(tid)

    # Tier 2: Explicit markers in episode text (supplementary)
    # Normalize the same way as Tier 1: uppercase, non-alphanumeric → underscore
    def _normalize_tid(raw: str) -> str:
        return re.sub(r'[^A-Z0-9]', '_', raw.strip().upper()).strip('_')

    plant_matches = re.findall(r'\[PLANT:\s*([^\]]+)\]', content, re.IGNORECASE)
    for m in plant_matches:
        tid = _normalize_tid(m)
        if tid not in threads["planted"]:
            threads["planted"].append(tid)

    advance_matches = re.findall(r'\[ADVANCE:\s*([^\]]+)\]', content, re.IGNORECASE)
    for m in advance_matches:
        tid = _normalize_tid(m)
        if tid not in threads["advanced"]:
            threads["advanced"].append(tid)

    payoff_matches = re.findall(r'\[PAYOFF:\s*([^\]]+)\]', content, re.IGNORECASE)
    for m in payoff_matches:
        tid = _normalize_tid(m)
        if tid not in threads["paid_off"]:
            threads["paid_off"].append(tid)

    return threads


def extract_emotional_beat(content: str) -> str | None:
    """
    Extract emotional beat from episode if present.
    Looks for [BEAT: BEAT_NAME] markers.
    """
    match = re.search(r'\[BEAT:\s*(\w+)\]', content, re.IGNORECASE)
    return match.group(1).upper() if match else None


def extract_key_moment(content: str) -> str:
    """Extract THE MOMENT from episode (brief description).

    Finds the first "THE MOMENT" label, either plain or bold
    (``**THE MOMENT**`` / ``**THE MOMENT:**``), and returns the text that
    follows it up to the end of that line, stripped and truncated to 100
    characters. The bold forms are matched together with their asterisks so
    the markup never ends up in the returned text.

    Returns:
        The description, or an empty string when no label is present.
    """
    # One alternation instead of two sequential searches: the plain pattern
    # also matches inside a bold label, so a separate bold search placed after
    # it could never run.
    match = re.search(
        r'(?:\*\*THE MOMENT:?\*\*|THE MOMENT)[:\s]*(.+?)(?:\n|$)',
        content, re.IGNORECASE
    )
    if match:
        return match.group(1).strip()[:100]  # Truncate to 100 chars
    return ""


def extract_cliffhanger_image(content: str) -> str:
    """Return a short description of the episode's closing visual.

    Takes the text under the ``# [mm:ss - mm:ss] THE CLIFFHANGER`` header, up
    to the next ``---`` or the end of the text, and returns its last
    non-empty line that is not a ``#`` heading, a ``*``-prefixed line or an
    all-caps line. The result is cut to 100 characters; an empty string is
    returned when there is no section or no qualifying line.
    """
    section_match = re.search(
        r'#\s*\[[\d:]+\s*-\s*[\d:]+\]\s*THE CLIFFHANGER.*?\n(.*?)(?=---|$)',
        content, re.DOTALL | re.IGNORECASE
    )
    if section_match is None:
        return ""

    # Walk the section bottom-up; the first qualifying line is the answer.
    for raw_line in reversed(section_match.group(1).split('\n')):
        text = raw_line.strip()
        if not text:
            continue
        if text.startswith(('#', '*')) or text.isupper():
            continue
        return text[:100]

    return ""


def extract_location(content: str) -> str:
    """Extract last scene location from episode.

    Looks at every ``INT.`` / ``EXT.`` scene heading and returns the location
    part of the last one: the text after the prefix up to the time-of-day
    separator or the end of the line, stripped and cut to 50 characters.

    The separator is a hyphen, en dash or em dash with whitespace on at least
    one side, so hyphenated names survive ("INT. CO-OP - DAY" -> "CO-OP").

    Returns:
        The location, or an empty string when there is no scene heading.
    """
    # Find all scene headings
    matches = re.findall(
        r'(?:INT\.|EXT\.)\s+(.+?)(?:\s+[-–—]|[-–—]\s|\n|$)',
        content
    )
    if matches:
        return matches[-1].strip()[:50]
    return ""


def analyze_episode(filepath: Path, ep_num: int, project_path: Path | None = None) -> dict:
    """Analyze a single episode and return summary data."""
    if not filepath.exists():
        return {
            "number": ep_num,
            "error": "File not found"
        }

    content = filepath.read_text()
    total_words = count_words(content)
    dialogue_words = count_dialogue_words(content)
    dialogue_percent = round(dialogue_words / total_words * 100, 1) if total_words > 0 else 0

    return {
        "number": ep_num,
        "title": extract_episode_title(content),
        "threads": extract_threads(content, ep_num, project_path),
        "emotional_beat": extract_emotional_beat(content),
        "hook_type": detect_hook_type(content),
        "cliffhanger_type": detect_cliffhanger_type(content),
        "word_count": total_words,
        "dialogue_percent": dialogue_percent,
        "exchanges": count_exchanges(content),
        "key_moment": extract_key_moment(content),
        "cliffhanger_image": extract_cliffhanger_image(content)
    }


def generate_batch_summary(project_path: Path, batch_num: int) -> dict:
    """Generate complete batch summary.

    Analyzes the GENERATION_BATCH_SIZE episodes that belong to ``batch_num``
    (batch 1 starts at episode 1) and aggregates their threads, emotional
    beats and hook/cliffhanger patterns.

    Args:
        project_path: Project root; episodes are read from its "episodes"
            sub-directory as ep_NNN.md.
        batch_num: 1-based batch number.

    Returns:
        Dict with the keys batch, timestamp, episodes, batch_summary and
        continuity_notes. Missing episode files are reported under
        batch_summary["issues_flagged"] and make validation_passed False.
        Thread lists are de-duplicated in first-seen order, so the output is
        stable between runs. Continuity notes describe the last episode that
        was actually found, and are empty when none was.
    """
    episodes_dir = project_path / "episodes"
    ep_start = (batch_num - 1) * GENERATION_BATCH_SIZE + 1
    ep_end = batch_num * GENERATION_BATCH_SIZE

    episode_data = []
    all_planted = []
    all_advanced = []
    all_paid_off = []
    all_beats = []
    hook_counts = {"silent": 0, "dialogue": 0}
    cliff_counts = {"mid_action": 0, "aftermath": 0}
    issues = []

    for ep_num in range(ep_start, ep_end + 1):
        ep_file = episodes_dir / f"ep_{ep_num:03d}.md"
        ep_data = analyze_episode(ep_file, ep_num, project_path)

        if "error" in ep_data:
            issues.append({
                "episode": ep_num,
                "issue_type": "missing_file",
                "description": ep_data["error"],
                "severity": "error"
            })
            continue

        episode_data.append(ep_data)

        # Aggregate threads
        all_planted.extend(ep_data["threads"]["planted"])
        all_advanced.extend(ep_data["threads"]["advanced"])
        all_paid_off.extend(ep_data["threads"]["paid_off"])

        # Aggregate beats
        if ep_data["emotional_beat"]:
            all_beats.append(ep_data["emotional_beat"])

        # Aggregate patterns; anything that is not the first category is
        # counted under the second one.
        hook_key = "silent" if ep_data["hook_type"] == "silent" else "dialogue"
        hook_counts[hook_key] += 1

        cliff_key = "mid_action" if ep_data["cliffhanger_type"] == "mid-action" else "aftermath"
        cliff_counts[cliff_key] += 1

    # Continuity notes come from the last episode that was actually analyzed,
    # so location, open tension and setup text all describe the same episode
    # even when the final file of the batch is missing.
    continuity = {}
    if episode_data:
        last_ep = episode_data[-1]
        last_num = last_ep["number"]
        last_file = episodes_dir / f"ep_{last_num:03d}.md"
        if last_file.is_file():
            content = last_file.read_text(encoding="utf-8")
            continuity = {
                "last_location": extract_location(content),
                "open_tension": last_ep.get("cliffhanger_image", ""),
                "next_batch_setup": f"Continue from Episode {last_num} cliffhanger"
            }

    summary = {
        "batch": batch_num,
        "timestamp": datetime.now().isoformat(),
        "episodes": episode_data,
        "batch_summary": {
            # dict.fromkeys de-duplicates while keeping first-seen order;
            # list(set(...)) produced a different order on every run.
            "threads_planted_this_batch": list(dict.fromkeys(all_planted)),
            "threads_advanced_this_batch": list(dict.fromkeys(all_advanced)),
            "threads_paid_off_this_batch": list(dict.fromkeys(all_paid_off)),
            "emotional_beats_hit_this_batch": all_beats,
            "pattern_distribution": {
                "hooks": hook_counts,
                "cliffhangers": cliff_counts
            },
            "validation_passed": len(issues) == 0,
            "issues_flagged": issues
        },
        "continuity_notes": continuity
    }

    return summary


def main():
    """Command-line entry point.

    Expects ``<project_path> <batch_number>`` in sys.argv, writes the batch
    summary as JSON into the project's state directory and prints a short
    report. Exits with status 1 on missing arguments, a batch number that is
    not a positive integer, or a project path that is not an existing
    directory.
    """
    if len(sys.argv) < 3:
        print("Usage: python3 generate_batch_summary.py <project_path> <batch_number>")
        print("Example: python3 generate_batch_summary.py ./leviathan 3")
        sys.exit(1)

    project_path = Path(sys.argv[1]).resolve()

    # Validate the batch number up front: a non-numeric value used to end in
    # a traceback, and zero/negative values produce negative episode numbers.
    try:
        batch_num = int(sys.argv[2])
    except ValueError:
        print(f"Error: Batch number must be an integer: {sys.argv[2]}")
        sys.exit(1)
    if batch_num < 1:
        print(f"Error: Batch number must be 1 or greater: {batch_num}")
        sys.exit(1)

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}")
        sys.exit(1)
    if not project_path.is_dir():
        print(f"Error: Project path is not a directory: {project_path}")
        sys.exit(1)

    # Generate summary
    summary = generate_batch_summary(project_path, batch_num)

    # Ensure state directory exists
    state_dir = ProjectPaths.from_root(project_path).state_dir
    state_dir.mkdir(parents=True, exist_ok=True)

    # Write summary
    output_path = state_dir / f"batch_{batch_num:02d}_summary.json"
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    # Print summary
    ep_start = (batch_num - 1) * GENERATION_BATCH_SIZE + 1
    ep_end = batch_num * GENERATION_BATCH_SIZE
    batch_summary = summary['batch_summary']

    print(f"\n{'='*60}")
    print(f"BATCH SUMMARY GENERATED: Batch {batch_num}")
    print(f"{'='*60}")
    print(f"\nEpisodes: {ep_start}-{ep_end}")
    print(f"Output: {output_path}")

    print("\nThreads:")
    print(f"  Planted: {batch_summary['threads_planted_this_batch']}")
    print(f"  Advanced: {batch_summary['threads_advanced_this_batch']}")
    print(f"  Paid off: {batch_summary['threads_paid_off_this_batch']}")

    print(f"\nEmotional beats hit: {batch_summary['emotional_beats_hit_this_batch']}")

    print("\nPattern distribution:")
    hooks = batch_summary['pattern_distribution']['hooks']
    cliffs = batch_summary['pattern_distribution']['cliffhangers']
    print(f"  Hooks: {hooks['silent']} silent / {hooks['dialogue']} dialogue")
    print(f"  Cliffhangers: {cliffs['mid_action']} mid-action / {cliffs['aftermath']} aftermath")

    if batch_summary['issues_flagged']:
        print(f"\nIssues flagged: {len(batch_summary['issues_flagged'])}")
        for issue in batch_summary['issues_flagged']:
            print(f"  - Ep {issue['episode']}: {issue['description']}")

    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    main()
