#!/usr/bin/env python3
"""
Extract Treatment Batch Tool
Extracts a batch of episodes from treatment.md for generation.

Usage:
    python extract_treatment_batch.py <treatment_path> <start_ep> <end_ep>
    python extract_treatment_batch.py <treatment_path> <start_ep> <end_ep> --output batch.md
    python extract_treatment_batch.py <treatment_path> <start_ep> <end_ep> --context

Options:
    --output, -o    Write to file instead of stdout
    --context, -c   Include context window (previously, this batch, coming)

Examples:
    python extract_treatment_batch.py olympus/treatment.md 11 15
    python extract_treatment_batch.py olympus/treatment.md 11 15 --context
    python extract_treatment_batch.py olympus/treatment.md 11 15 -o batch_3.md
"""

import re
import sys
import argparse
from pathlib import Path

# Import TOTAL_EPISODES from shared constants (engine_constants.py next to this
# script). The script's directory is placed on sys.path only for the duration
# of the import; the finally clause restores sys.path whether or not the
# import succeeds.
_SCRIPT_DIR = str(Path(__file__).parent)
sys.path.insert(0, _SCRIPT_DIR)
try:
    from engine_constants import TOTAL_EPISODES
except ImportError:
    # Fallback when engine_constants is not available.
    TOTAL_EPISODES = 60
finally:
    # Remove our own entry rather than popping index 0, in case the imported
    # module itself modified sys.path.
    try:
        sys.path.remove(_SCRIPT_DIR)
    except ValueError:
        # Entry already removed by someone else; nothing to restore.
        pass


def parse_episodes(content):
    """Parse treatment markdown into a dict of episode blocks keyed by number.

    The text is split immediately before every ``### Episode <n>`` header,
    case-insensitively (the same rule the header pattern uses). Each chunk
    that starts with a full ``### Episode <n>: "<title>"`` header becomes one
    entry. Any preamble before the first header is discarded. Trailing
    separator lines (``---`` or longer runs of dashes on their own line) are
    removed from each block.

    Args:
        content: Full text of the treatment file.

    Returns:
        dict mapping episode number (int) to the stripped block text, header
        included. If a number appears more than once, the last block wins.
    """
    episodes = {}

    # Header with a quoted title, e.g. ### Episode 12: "The Fall"
    episode_header_pattern = re.compile(
        r'###\s*Episode\s*(\d+):\s*"([^"]+)"',
        re.IGNORECASE
    )
    # Zero-width split point before each episode header. It must be as
    # case-insensitive as the header pattern, otherwise "### episode 3" is
    # folded into the previous episode's block and silently lost.
    split_pattern = re.compile(r'(?=###\s*Episode\s*\d+)', re.IGNORECASE)
    # One or more trailing separator lines at the very end of a block.
    # Requiring a preceding newline keeps prose that ends in "---" intact.
    trailing_rule_pattern = re.compile(r'(?:\n[ \t]*-{3,}\s*)+\Z')

    for block in split_pattern.split(content):
        # After the split, every chunk except the preamble begins with the
        # header, so an anchored match is sufficient.
        match = episode_header_pattern.match(block)
        if not match:
            continue
        cleaned = trailing_rule_pattern.sub('', block.strip()).strip()
        episodes[int(match.group(1))] = cleaned

    return episodes


def get_episode_summary(episode_text):
    """Return the first sentence of an episode's prose, for use in context lines.

    Prose collection is switched on by the first header line (starting with
    ``###``) or metadata-marker line. Those lines themselves are skipped.
    Collection stops at the first ``**[CLIFFHANGER:`` line. Blank lines and
    lines beginning with ``---`` are ignored. The collected lines are joined
    with spaces, and the text up to the first ``.``, ``!`` or ``?`` followed
    by whitespace is returned. Falls back to ``"Episode content."`` when no
    prose is found.
    """
    skip_markers = ('**Sequence:**', '**Threads:**', '**THE MOMENT:**')
    collected = []
    started = False

    for raw in episode_text.split('\n'):
        is_marker = raw.startswith('###') or any(m in raw for m in skip_markers)
        if is_marker:
            started = True
            continue
        if '**[CLIFFHANGER:' in raw:
            break
        text = raw.strip()
        if started and text and not raw.startswith('---'):
            collected.append(text)

    if not collected:
        return "Episode content."

    # Split on sentence boundaries and keep only the leading sentence.
    first, *_ = re.split(r'(?<=[.!?])\s+', ' '.join(collected))
    return first


def _ellipsize(text, limit=50):
    """Return text cut to limit characters, appending "..." only if it was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def extract_context(episodes, start_ep, end_ep, total_episodes=None):
    """Build the context lines ("Previously", "This batch", "Coming") for a batch.

    Args:
        episodes: dict of episode number -> episode text (as returned by
            parse_episodes).
        start_ep: First episode number of the batch (inclusive).
        end_ep: Last episode number of the batch (inclusive).
        total_episodes: Last episode of the series; no "Coming" line is
            produced once end_ep reaches it. Defaults to the module-level
            TOTAL_EPISODES, so it agrees with the CLI's range validation.

    Returns:
        List of markdown lines. A line is omitted when the neighbouring
        episode (or every episode of the batch) is absent from episodes.
    """
    if total_episodes is None:
        total_episodes = TOTAL_EPISODES

    context = []

    # PREVIOUSLY (what happened before this batch)
    if start_ep > 1:
        prev_ep = start_ep - 1
        if prev_ep in episodes:
            prev_summary = get_episode_summary(episodes[prev_ep])
            context.append(f"**Previously (Ep {prev_ep}):** {prev_summary}")

    # THIS BATCH: summarize the first and last episodes actually present.
    batch_eps = [ep for ep in range(start_ep, end_ep + 1) if ep in episodes]
    if batch_eps:
        first_summary = _ellipsize(get_episode_summary(episodes[batch_eps[0]]))
        last_summary = _ellipsize(get_episode_summary(episodes[batch_eps[-1]]))
        context.append(f"**This batch (Ep {start_ep}-{end_ep}):** From \"{first_summary}\" to \"{last_summary}\"")

    # COMING (where this batch is heading)
    if end_ep < total_episodes:
        next_ep = end_ep + 1
        if next_ep in episodes:
            next_summary = get_episode_summary(episodes[next_ep])
            context.append(f"**Coming (Ep {next_ep}):** {next_summary}")

    return context


def extract_batch(treatment_path, start_ep, end_ep, include_context=False):
    """Return a markdown batch of episodes start_ep..end_ep from a treatment file.

    The result contains, in order:

    * a "## Treatment Entries" heading;
    * an optional "### CONTEXT" section, when include_context is true;
    * the selected episode blocks, separated by horizontal rules.

    Each missing episode is reported on stderr and skipped. An empty string
    is returned, with a warning on stderr, when the file holds no episodes or
    none fall inside the requested range.
    """
    all_episodes = parse_episodes(Path(treatment_path).read_text(encoding='utf-8'))
    if not all_episodes:
        print(f"Warning: No episodes found in {treatment_path}", file=sys.stderr)
        return ""

    # Collect the requested episodes, warning about any gaps.
    selected = []
    for number in range(start_ep, end_ep + 1):
        if number not in all_episodes:
            print(f"Warning: Episode {number} not found", file=sys.stderr)
            continue
        selected.append(all_episodes[number])

    if not selected:
        print(f"Warning: No episodes found in range {start_ep}-{end_ep}", file=sys.stderr)
        return ""

    lines = [f"## Treatment Entries (Episodes {start_ep}-{end_ep})", ""]
    if include_context:
        lines.extend(["### CONTEXT", ""])
        lines.extend(extract_context(all_episodes, start_ep, end_ep))
        lines.extend(["", "---", ""])
    lines.append("\n\n---\n\n".join(selected))
    return "\n".join(lines)


def main():
    """Command-line entry point for extracting a treatment batch.

    Parses the arguments and validates them. It then extracts the requested
    episodes and either writes them to --output (printing a confirmation) or
    prints them to stdout.

    Exits with status 1 in any of these cases:

    * the treatment path is not a readable regular file;
    * the episode range is out of bounds or reversed;
    * the output file cannot be written;
    * no episodes could be extracted.
    """
    parser = argparse.ArgumentParser(
        description="Extract treatment batch for generation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python extract_treatment_batch.py olympus/treatment.md 11 15
    python extract_treatment_batch.py olympus/treatment.md 11 15 --context
    python extract_treatment_batch.py olympus/treatment.md 11 15 -o batch_3.md
        """
    )
    parser.add_argument("treatment", help="Path to treatment.md")
    parser.add_argument("start", type=int, help="Start episode number")
    parser.add_argument("end", type=int, help="End episode number")
    parser.add_argument("--output", "-o", help="Output file (default: stdout)")
    parser.add_argument("--context", "-c", action="store_true",
                        help="Include context window (previously, this batch, coming)")

    args = parser.parse_args()

    # Validate inputs. is_file() rather than exists(): a directory would pass
    # exists() and then crash inside read_text().
    if not Path(args.treatment).is_file():
        print(f"Error: File not found: {args.treatment}", file=sys.stderr)
        sys.exit(1)

    if args.start < 1 or args.end > TOTAL_EPISODES:
        print(f"Error: Episode range must be between 1 and {TOTAL_EPISODES}", file=sys.stderr)
        sys.exit(1)

    if args.start > args.end:
        print("Error: Start episode must be <= end episode", file=sys.stderr)
        sys.exit(1)

    # Extract batch; report unreadable or undecodable files instead of
    # dumping a traceback.
    try:
        result = extract_batch(args.treatment, args.start, args.end, args.context)
    except (OSError, UnicodeDecodeError) as err:
        print(f"Error: Could not read {args.treatment}: {err}", file=sys.stderr)
        sys.exit(1)

    # extract_batch returns "" (after printing a warning) when nothing was
    # extracted. Fail instead of writing an empty batch and reporting success.
    if not result:
        sys.exit(1)

    if args.output:
        try:
            Path(args.output).write_text(result, encoding='utf-8')
        except OSError as err:
            print(f"Error: Could not write {args.output}: {err}", file=sys.stderr)
            sys.exit(1)
        print(f"Extracted episodes {args.start}-{args.end} to {args.output}")
    else:
        print(result)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
