#!/usr/bin/env python3
"""
Fountain Screenplay Audio Reader — ElevenLabs + Qwen3 TTS

Parse .fountain screenplays, assign distinct voices per character,
generate per-episode MP3s via ElevenLabs API or Qwen3-TTS (local).

Commands:
    fountain_reader.py <project> init            # Create voice_config.yaml
    fountain_reader.py <project> dry-run         # Char counts, no API calls
    fountain_reader.py <project> generate        # Generate MP3 audio files
    fountain_reader.py <project> voice-direct    # Analyze parenthetical coverage
    fountain_reader.py <project> budget          # Monthly credit usage
    fountain_reader.py <project> voices          # List ElevenLabs voices

Flags:
    --episode N          Single episode
    --episodes N-M       Episode range
    --dialogue-only      Skip narration (saves ~60% credits)
    --concat             Single output file
    --force              Regenerate existing files

Environment:
    ELEVEN_API_KEY       ElevenLabs API key (required for generate/voices with elevenlabs engine)

Engine selection:
    Set "engine:" in voice_config.yaml:
      elevenlabs   — ElevenLabs API (requires ELEVEN_API_KEY)
      qwen3        — Qwen3 CustomVoice, static per-character instruct
      qwen3_clone  — Qwen3 ICL voice cloning from ref WAVs (best identity consistency)

Dependencies:
    pip install elevenlabs pydub pyyaml
    pip install qwen-tts soundfile (for qwen3/qwen3_clone engines)
    ffmpeg (system — already installed for ComfyUI)
"""

import argparse
import hashlib
import io
import json
import os
import re
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path

from recoil.core.model_profiles import get_model

try:
    import yaml
except ImportError:
    yaml = None

try:
    from elevenlabs.client import ElevenLabs as ElevenLabsClient
    from elevenlabs import VoiceSettings
except ImportError:
    ElevenLabsClient = None
    VoiceSettings = None

try:
    from pydub import AudioSegment
except ImportError:
    AudioSegment = None


# ── Data ──

@dataclass
class Segment:
    """One speakable unit extracted from a fountain screenplay."""

    # Kind of segment: "dialogue", "action" or "scene_heading"
    type: str
    # Speakable text
    text: str
    # Character name, or "NARRATOR" for non-dialogue segments
    character: str
    # Episode number the segment belongs to
    episode: int
    # Section label, e.g. THE HOOK, THE SETUP
    section: str
    # Per-line emotional direction (from fountain parenthetical)
    instruct: str = ""


# ── Constants ──

# Character allowance used as the denominator of the "Budget usage" report.
MONTHLY_CHAR_BUDGET = 60_000
# Default TTS model: written into new configs and used when "model" is unset.
DEFAULT_MODEL = get_model("elevenlabs", "tts")
# Default output_format passed to the ElevenLabs call when the config has none.
DEFAULT_FORMAT = "mp3_44100_128"
# Dollar rate per 1,000 characters used for cost estimates.
RATE_PER_1K_CHARS = 0.30

# An all-caps line whose first word (trailing "." removed) is in this set is
# rejected as a character cue by _is_character_name.
CAMERA_DIRS = frozenset({
    "POV", "ECU", "CU", "MCU", "MS", "MLS", "LS", "ELS",
    "ANGLE", "CLOSE", "WIDE", "INSERT", "INTERCUT", "FLASHBACK",
    "FADE", "CUT", "SMASH", "MATCH", "SLOW", "PUSH", "PULL",
    "TRACKING", "CRANE", "TILT", "PAN", "MONTAGE", "SERIES",
})


# ── Fountain Parser ──

def find_fountain_file(project_path, config=None):
    """Locate the screenplay to read inside *project_path*.

    Resolution order:
      1. ``config["fountain_file"]`` when it is set and exists on disk.
      2. The most recently modified ``*.fountain`` whose name contains
         "COMPLETE" (case-insensitive).
      3. The most recently modified ``*.fountain`` of any name.

    Raises:
        FileNotFoundError: the directory holds no ``.fountain`` files.
    """
    configured = config.get("fountain_file") if config else None
    if configured:
        explicit = project_path / configured
        if explicit.exists():
            return explicit

    def _mtime(path):
        return path.stat().st_mtime

    newest_first = sorted(project_path.glob("*.fountain"), key=_mtime, reverse=True)
    if not newest_first:
        raise FileNotFoundError(f"No .fountain files in {project_path}")

    # Prefer a compiled "COMPLETE" script over any other draft
    for candidate in newest_first:
        if "COMPLETE" in candidate.name.upper():
            return candidate
    return newest_first[0]


def _is_character_name(line):
    """Return the cleaned cue name when *line* looks like a character cue.

    A cue is a short (at most 50 characters, at most four words) all-caps
    line, optionally followed by a parenthetical extension. Lines that open
    with a camera direction or a scene-heading prefix are rejected.
    Returns None when the line is not a cue.
    """
    stripped = line.strip()
    if len(stripped) == 0 or len(stripped) > 50:
        return None

    # Drop a trailing extension such as (O.S.), (V.O.) or (CONT'D)
    cue = re.sub(r"\s*\([^)]*\)\s*$", "", stripped).strip()
    if not cue:
        return None

    # Uppercase letters, digits and a little punctuation only
    if re.match(r"^[A-Z][A-Z0-9 .'\-]*$", cue) is None:
        return None

    tokens = cue.split()
    if len(tokens) > 4:
        return None

    opens_with_camera_dir = tokens[0].rstrip(".") in CAMERA_DIRS
    opens_like_scene_heading = cue.startswith(("INT.", "EXT.", "INT/", "I/E"))
    if opens_with_camera_dir or opens_like_scene_heading:
        return None

    return cue


def parse_fountain(fountain_path):
    """Parse a compiled .fountain file into ordered Segments.

    The file is read as UTF-8 and scanned line by line. Every line up to and
    including the first ``===`` break is treated as the title page and
    ignored. After that, each non-blank line is classified by the first rule
    that matches, in this order: episode marker, timed section header, other
    ``#`` header, centered text, scene heading, character cue with dialogue,
    and finally action.

    Returns:
        list[Segment] in document order. Scene headings and action lines are
        attributed to "NARRATOR"; dialogue is attributed to the cue name.
    """
    lines = fountain_path.read_text(encoding="utf-8").split("\n")
    segments = []
    episode = 0          # stays 0 for content before the first episode marker
    section = ""         # most recent timed section header, "" until one is seen
    in_title_page = True

    # [[EPISODE 3: Title]] — only the number is used; the title is ignored
    ep_re = re.compile(r"^\[\[EPISODE\s+(\d+):\s*(.+?)\]\]$")
    # "# [00:00 - 00:05] THE HOOK" — captures the label after the time range
    sec_re = re.compile(r"^#\s*\[[\d:]+\s*-\s*[\d:]+\]\s*(.+)$")
    scene_re = re.compile(r"^(INT\.|EXT\.|INT\./EXT\.|I/E)", re.I)
    # A whole line wrapped in parentheses
    paren_re = re.compile(r"^\(.*\)\s*$")
    # Three or more "=" on a line by themselves
    break_re = re.compile(r"^={3,}\s*$")

    i, n = 0, len(lines)
    while i < n:
        s = lines[i].strip()

        # Title page — skip until first ===
        if in_title_page:
            if break_re.match(s):
                in_title_page = False
            i += 1
            continue

        # Skip blanks and page breaks
        if not s or break_re.match(s):
            i += 1
            continue

        # Episode marker
        m = ep_re.match(s)
        if m:
            episode = int(m.group(1))
            i += 1
            continue

        # Section header (# [00:00 - 00:05] THE HOOK)
        m = sec_re.match(s)
        if m:
            section = m.group(1).strip()
            i += 1
            continue

        # Other markdown headers
        if s.startswith("#"):
            i += 1
            continue

        # Centered text (> ... <)
        if s.startswith(">") and s.endswith("<"):
            i += 1
            continue

        # Scene heading
        if scene_re.match(s):
            # A trailing "- CONTINUOUS" is removed before the heading is spoken
            text = re.sub(r"\s*-\s*CONTINUOUS\s*$", "", s, flags=re.I)
            text = _normalize_tts_text(text)
            segments.append(Segment("scene_heading", text, "NARRATOR", episode, section))
            i += 1
            continue

        # Character cue — look ahead for dialogue
        char = _is_character_name(s)
        if char:
            j = i + 1
            dlg = []
            instruct = ""
            # Consume lines until the first blank line (or end of file)
            while j < n:
                ls = lines[j].strip()
                if not ls:
                    break
                if paren_re.match(ls):
                    # Capture first parenthetical as per-line emotional direction
                    # (later parentheticals in the same block are dropped)
                    if not instruct:
                        instruct = ls[1:-1].strip()
                    j += 1
                    continue
                dlg.append(ls)
                j += 1
            if dlg:
                # Multi-line dialogue is joined into one segment with spaces
                segments.append(Segment(
                    "dialogue", _normalize_tts_text(" ".join(dlg)), char, episode, section,
                    instruct=instruct))
                i = j
                continue
            # No dialogue found → fall through to action

        # Action line (lines of one or two characters are dropped)
        if len(s) > 2:
            segments.append(Segment("action", _normalize_tts_text(s), "NARRATOR", episode, section))
        i += 1

    return segments


def _normalize_tts_text(text):
    """Normalize text for TTS pronunciation.

    Expands scene-heading abbreviations and fixes formatting that confuses
    speech models, then strips surrounding whitespace.

    The combined ``INT./EXT.`` form is expanded before the single ``INT.``
    and ``EXT.`` forms. If the single forms ran first they would consume
    each half separately and leave a stray slash ("Interior /Exterior ...").
    """
    # Scene heading abbreviations — the most specific pattern goes first
    text = re.sub(r'\bINT\./EXT\.\s*', 'Interior exterior ', text)
    text = re.sub(r'\bINT\.\s*', 'Interior ', text)
    text = re.sub(r'\bEXT\.\s*', 'Exterior ', text)
    # "I/E" may be written with a trailing period; consume it so the
    # expansion is not followed by a dangling "."
    text = re.sub(r'\bI/E\b\.?', 'Interior exterior', text)
    # Hyphenated nicknames — remove hyphen so TTS reads as two words
    # (hyphens cause pauses in some TTS models)
    text = re.sub(r'(?i)\bchrome-boy\b', 'Chrome Boy', text)
    return text.strip()


# ── Filtering & Analysis ──

def filter_segments(segments, dialogue_only=False, episodes=None):
    """Narrow *segments* to the requested episodes and/or dialogue only.

    With no active filter the input list itself is returned; otherwise a
    new list is built in the original order.
    """
    if not episodes and not dialogue_only:
        return segments

    def _keep(seg):
        if episodes and seg.episode not in episodes:
            return False
        if dialogue_only and seg.type != "dialogue":
            return False
        return True

    return [seg for seg in segments if _keep(seg)]


def get_characters(segments):
    """Tally dialogue segments per character, busiest speaker first.

    Only segments of type "dialogue" are counted, so NARRATOR never
    appears. Characters with equal counts keep first-appearance order.
    """
    tally = {}
    for seg in segments:
        if seg.type != "dialogue":
            continue
        tally[seg.character] = tally.get(seg.character, 0) + 1
    ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)


def chars_by_voice(segments):
    """Sum text length per speaking voice, largest total first.

    Every segment counts, so NARRATOR is included. Voices with equal
    totals keep first-appearance order.
    """
    totals = {}
    for seg in segments:
        voice = seg.character
        totals[voice] = totals.get(voice, 0) + len(seg.text)
    ranked = sorted(totals.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)


def chars_by_episode(segments):
    """Sum text length per episode, ordered by ascending episode number."""
    totals = {}
    for seg in segments:
        ep = seg.episode
        totals[ep] = totals.get(ep, 0) + len(seg.text)
    return {ep: totals[ep] for ep in sorted(totals)}


# ── Voice Config ──

def config_path(project_path):
    """Return the location of the project's voice_config.yaml."""
    return project_path.joinpath("audio", "voice_config.yaml")


def credit_log_path(project_path):
    """Return the location of the project's credit_log.json."""
    return project_path.joinpath("audio", "credit_log.json")


def load_config(project_path):
    """Load voice_config.yaml for *project_path* and return it as a dict.

    Exits with status 1 (after printing to stderr) when PyYAML is not
    installed, when the config file does not exist, or when the file's
    top level is not a mapping.

    An empty file (or one containing only comments) makes
    ``yaml.safe_load`` return None; that case is returned as ``{}`` so
    callers can always use ``config.get(...)``.
    """
    if yaml is None:
        print("ERROR: pyyaml required. pip install pyyaml", file=sys.stderr)
        sys.exit(1)
    cp = config_path(project_path)
    if not cp.exists():
        print(f"ERROR: No voice config. Run 'init' first: {cp}", file=sys.stderr)
        sys.exit(1)
    with open(cp) as f:
        config = yaml.safe_load(f)
    if config is None:
        return {}
    if not isinstance(config, dict):
        print(f"ERROR: Voice config must be a YAML mapping: {cp}", file=sys.stderr)
        sys.exit(1)
    return config


def resolve_voice(character, config):
    """Map a character name to ``(voice, settings)`` for the active engine.

    The engine is read from ``config["engine"]`` (default "elevenlabs"):
      - "qwen3": returns (speaker_name, {instruct})
      - anything else: returns (voice_id, {stability, similarity_boost})
    """
    uses_qwen3 = config.get("engine", "elevenlabs") == "qwen3"
    resolver = _resolve_voice_qwen3 if uses_qwen3 else _resolve_voice_elevenlabs
    return resolver(character, config)


def _resolve_voice_elevenlabs(character, config):
    """Resolve *character* to ``(voice_id, settings)`` for ElevenLabs.

    Lookup order: the ``narrator`` section for "NARRATOR", then the
    ``characters`` section, then a deterministic pick from
    ``side_voice_pool`` based on an MD5 hash of the name. ``settings``
    always has "stability" and "similarity_boost" keys (defaults 0.5 and
    0.75). An empty voice_id means no voice is configured.

    Sections or entries that are present but null in the YAML (for example
    a bare ``characters:`` line) are treated as empty rather than crashing.
    """
    default_settings = {"stability": 0.5, "similarity_boost": 0.75}

    def _from_entry(entry):
        # A null entry behaves like an entry with no fields filled in
        entry = entry or {}
        return entry.get("voice_id") or "", {
            "stability": entry.get("stability", 0.5),
            "similarity_boost": entry.get("similarity_boost", 0.75),
        }

    # Narrator
    if character == "NARRATOR":
        return _from_entry(config.get("narrator"))

    # Named character
    characters = config.get("characters") or {}
    if character in characters:
        return _from_entry(characters[character])

    # Side voice pool — deterministic assignment (same name, same voice)
    pool = config.get("side_voice_pool") or []
    if pool:
        idx = int(hashlib.md5(character.encode()).hexdigest(), 16) % len(pool)
        entry = pool[idx]
        # Pool entries may be {"voice_id": ...} mappings or bare id strings
        vid = entry.get("voice_id", "") if isinstance(entry, dict) else entry
        return vid or "", default_settings

    return "", default_settings


def _resolve_voice_qwen3(character, config):
    """Resolve *character* to ``(speaker, settings)`` for Qwen3.

    Lookup order: the ``narrator`` section for "NARRATOR", then the
    ``characters`` section, then a deterministic pick from
    ``side_voice_pool`` based on an MD5 hash of the name. ``settings``
    always has an "instruct" key (default ""). An empty speaker means no
    voice is configured.

    Sections or entries that are present but null in the YAML (for example
    a bare ``characters:`` line) are treated as empty rather than crashing.
    """
    default_settings = {"instruct": ""}

    def _from_entry(entry):
        # A null entry behaves like an entry with no fields filled in
        entry = entry or {}
        return entry.get("speaker") or "", {
            "instruct": entry.get("instruct") or "",
        }

    # Narrator
    if character == "NARRATOR":
        return _from_entry(config.get("narrator"))

    # Named character
    characters = config.get("characters") or {}
    if character in characters:
        return _from_entry(characters[character])

    # Side voice pool — deterministic assignment (same name, same voice)
    pool = config.get("side_voice_pool") or []
    if pool:
        idx = int(hashlib.md5(character.encode()).hexdigest(), 16) % len(pool)
        entry = pool[idx]
        # Pool entries may be {"speaker": ...} mappings or bare name strings
        speaker = entry.get("speaker", "") if isinstance(entry, dict) else entry
        return speaker or "", default_settings

    return "", default_settings


# ── Credit Log ──

def load_credit_log(project_path):
    """Return the list of credit entries stored in credit_log.json.

    Returns an empty list when the log does not exist. When the log exists
    but cannot be read, is not valid JSON, or does not have the expected
    ``{"entries": [...]}`` shape, a warning is printed to stderr and an
    empty list is returned, so a damaged log never aborts generation.
    """
    lp = credit_log_path(project_path)
    if not lp.exists():
        return []
    try:
        with open(lp) as f:
            data = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and undecodable bytes
        print(f"WARNING: Could not read credit log {lp}: {e}", file=sys.stderr)
        return []
    entries = data.get("entries", []) if isinstance(data, dict) else None
    if not isinstance(entries, list):
        # Callers append to the result, so anything but a list is unusable
        print(f"WARNING: Unexpected credit log format, ignoring: {lp}", file=sys.stderr)
        return []
    return entries


def save_credit_log(project_path, entries):
    """Write *entries* to credit_log.json, replacing any existing file.

    The parent directory is created when missing. The file records the
    project name and a UTC "updated" timestamp next to the entries.
    """
    target = credit_log_path(project_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w") as handle:
        payload = {
            "project": project_path.name,
            "updated": datetime.now(timezone.utc).isoformat(),
            "entries": entries,
        }
        json.dump(payload, handle, indent=2)


def log_credits(project_path, episode, character, char_count):
    """Record one usage entry (UTC timestamped) in credit_log.json.

    The whole log is loaded, extended by one entry and written back.
    """
    history = load_credit_log(project_path)
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "episode": episode,
        "character": character,
        "chars": char_count,
    }
    history.append(record)
    save_credit_log(project_path, history)


# ── Cost Tracker Integration ──

def log_to_cost_tracker(project_path, episode, char_count, segment_count,
                        provider="elevenlabs"):
    """Report a voice-generation run to the production cost tracker.

    The ``cost_tracker`` module is looked up next to this file (its
    directory is put on ``sys.path`` if needed). Any ImportError raised
    along the way is swallowed, making the tracker an optional integration.
    Only the "elevenlabs" provider is billed; other providers log cost 0.0.
    """
    try:
        here = str(Path(__file__).resolve().parent)
        if here not in sys.path:
            sys.path.insert(0, here)
        from cost_tracker import CostTracker

        tracker = CostTracker(str(project_path))
        if provider == "elevenlabs":
            cost = char_count / 1000 * RATE_PER_1K_CHARS
            model = "tts_api"
        else:
            cost = 0.0
            model = "qwen3_local"
        tracker.log(
            category="voice",
            provider=provider,
            model=model,
            success=True,
            cost_override=cost,
            episode=episode,
            detail=f"{char_count} chars, {segment_count} segments",
        )
    except ImportError:
        return


# ── Commands ──

def cmd_init(project_path):
    """Scan the fountain file and write a voice_config.yaml template.

    Characters with 10 or more dialogue lines get their own entry under
    ``characters``; the rest are listed in a comment and are expected to
    be served from ``side_voice_pool``. An existing config file at the
    target path is overwritten.

    Exits with status 1 when PyYAML is not installed. Returns without
    writing anything when the screenplay has no dialogue.
    """
    if yaml is None:
        print("ERROR: pyyaml required. pip install pyyaml", file=sys.stderr)
        sys.exit(1)

    fountain = find_fountain_file(project_path)
    print(f"Parsing: {fountain.name}")

    segments = parse_fountain(fountain)
    characters = get_characters(segments)

    if not characters:
        print("No characters found.")
        return

    # Classify: main (10+ lines) vs side
    mains = {k: v for k, v in characters.items() if v >= 10}
    sides = {k: v for k, v in characters.items() if v < 10}

    def _quoted(value):
        # A JSON string literal is also a valid YAML double-quoted scalar.
        # Quoting keeps names such as NO / ON / 1984 as strings and protects
        # values that contain ": " or " #".
        return json.dumps(str(value))

    # Write config with comments (manual YAML for readability)
    cp = config_path(project_path)
    cp.parent.mkdir(parents=True, exist_ok=True)

    out = [
        f"# Voice config for {project_path.name}\n",
        f"# Generated {datetime.now().strftime('%Y-%m-%d %H:%M')}\n",
        "# Fill in voice_id fields from ElevenLabs Voice Library\n",
        f"# Run: fountain_reader.py {project_path.name} voices\n\n",
        f"project: {_quoted(project_path.name)}\n",
        f"fountain_file: {_quoted(fountain.name)}\n",
        f"model: {DEFAULT_MODEL}\n",
        f"output_format: {DEFAULT_FORMAT}\n\n",
        "narrator:\n",
        '  voice_id: ""    # FILL IN\n',
        "  stability: 0.5\n",
        "  similarity_boost: 0.75\n\n",
    ]
    if mains:
        out.append("characters:\n")
        for name, count in mains.items():
            out.append(f"  {_quoted(name)}:\n")
            out.append(f'    voice_id: ""    # FILL IN — {count} lines\n')
    else:
        # A bare "characters:" key would load as null and break consumers
        # that iterate the mapping, so write an explicit empty mapping.
        out.append("characters: {}\n")
    out.append("\n")
    if sides:
        out.append(f"# Side characters ({len(sides)}): "
                   f"{', '.join(sides.keys())}\n")
        out.append("# Assigned from pool via hash(name) % pool_size\n")
    out.extend([
        "side_voice_pool:\n",
        '  - voice_id: ""    # FILL IN — pool voice 1\n',
        '  - voice_id: ""    # FILL IN — pool voice 2\n',
        '  - voice_id: ""    # FILL IN — pool voice 3\n\n',
        "silence_ms:\n",
        "  between_segments: 200\n",
        "  section_break: 500\n",
        "  episode_break: 1000\n",
    ])

    with open(cp, "w") as f:
        f.writelines(out)

    print(f"\nCreated: {cp}")
    print("\nCharacters found:")
    print(f"  Main ({len(mains)}):")
    for name, count in mains.items():
        print(f"    {name}: {count} lines")
    if sides:
        print(f"  Side ({len(sides)}):")
        for name, count in sides.items():
            print(f"    {name}: {count} lines")
    print("\nNext: fill in voice_ids, then run 'dry-run'.")


def cmd_dry_run(project_path, episodes, dialogue_only):
    """Print character counts and a cost estimate; makes no API calls.

    Reports totals, a per-episode breakdown with the average, and a
    per-voice breakdown with each voice's share of the total.
    """
    fountain = find_fountain_file(project_path)
    print(f"Parsing: {fountain.name}")

    segs = filter_segments(
        parse_fountain(fountain),
        dialogue_only=dialogue_only,
        episodes=episodes,
    )
    if not segs:
        print("No segments found for the given filters.")
        return

    if dialogue_only:
        mode = "dialogue-only"
    else:
        mode = "full narration"
    total = sum(len(seg.text) for seg in segs)
    per_episode = chars_by_episode(segs)
    per_voice = chars_by_voice(segs)
    cost = total / 1000 * RATE_PER_1K_CHARS
    budget_pct = total / MONTHLY_CHAR_BUDGET * 100

    # Summary
    print(f"\nMode: {mode}")
    print(f"Episodes: {len(per_episode)}")
    print(f"Segments: {len(segs)}")
    print(f"Total chars: {total:,}")
    print(f"Estimated cost: ${cost:.2f} (at ${RATE_PER_1K_CHARS}/1K chars)")
    print(f"Budget usage: {total:,} / {MONTHLY_CHAR_BUDGET:,} ({budget_pct:.0f}%)")

    # Per-episode breakdown
    print("\nBy Episode:")
    for ep, chars in per_episode.items():
        print(f"  Ep {ep:3d}: {chars:5,} chars")
    avg = total // len(per_episode) if per_episode else 0
    print(f"  Average: {avg:,} chars/episode")

    # Per-voice breakdown
    print("\nBy Voice:")
    for voice, chars in per_voice.items():
        share = chars / total * 100
        print(f"  {voice:20s}: {chars:6,} chars ({share:.0f}%)")


def cmd_generate(project_path, episodes, dialogue_only, concat, force):
    """Generate episode audio with the engine named in voice_config.yaml.

    ``engine`` may be "qwen3", "qwen3_clone" or anything else; any other
    value, including a missing key, selects ElevenLabs. Exits with status 1
    when pydub is not installed.
    """
    if AudioSegment is None:
        print("ERROR: pip install pydub", file=sys.stderr)
        sys.exit(1)

    config = load_config(project_path)
    engine = config.get("engine", "elevenlabs")

    # Pick the backend first, then make a single call
    if engine == "qwen3":
        backend = _generate_qwen3
    elif engine == "qwen3_clone":
        backend = _generate_qwen3_clone
    else:
        backend = _generate_elevenlabs
    backend(project_path, config, episodes, dialogue_only, concat, force)


def _generate_elevenlabs(project_path, config, episodes, dialogue_only, concat, force):
    """Generate MP3 audio files via ElevenLabs TTS.

    One ``audio/episodes/ep_NNN.mp3`` is written per episode. Episodes whose
    file already exists are skipped unless *force* is set. With *concat*,
    all episodes (freshly generated or already on disk) are also joined into
    ``audio/<project>_complete.mp3``.

    Segments whose voice resolves to an empty id are skipped. A segment
    whose TTS call fails is reported and skipped; the episode continues.
    Character usage is logged per episode after a successful export.

    Exits with status 1 when the elevenlabs package is missing or
    ELEVEN_API_KEY is not set.
    """
    if ElevenLabsClient is None:
        print("ERROR: pip install elevenlabs", file=sys.stderr)
        sys.exit(1)

    api_key = os.environ.get("ELEVEN_API_KEY")
    if not api_key:
        print("ERROR: Set ELEVEN_API_KEY environment variable.", file=sys.stderr)
        sys.exit(1)

    # YAML sections that are present but empty load as None (for example a
    # bare "characters:" line). Work on a normalised copy so neither the
    # checks below nor resolve_voice() trip over None.
    config = dict(config)
    config["narrator"] = config.get("narrator") or {}
    config["characters"] = {
        name: (entry or {})
        for name, entry in (config.get("characters") or {}).items()
    }
    config["side_voice_pool"] = config.get("side_voice_pool") or []
    config["silence_ms"] = config.get("silence_ms") or {}

    fountain = find_fountain_file(project_path, config)
    all_segs = parse_fountain(fountain)
    segs = filter_segments(all_segs, dialogue_only=dialogue_only, episodes=episodes)

    if not segs:
        print("No segments to generate.")
        return

    # Check for empty voice_ids (warning only; generation still proceeds)
    missing = []
    nr_vid = config["narrator"].get("voice_id", "")
    if not nr_vid and not dialogue_only:
        missing.append("narrator")
    for name, ch in config["characters"].items():
        if not ch.get("voice_id", ""):
            missing.append(name)
    pool = config["side_voice_pool"]
    # True for an empty pool as well as a pool of blank entries
    pool_empty = not any(
        (e.get("voice_id", "") if isinstance(e, dict) else e) for e in pool)
    if pool_empty:
        missing.append("side_voice_pool (all empty)")

    if missing:
        print(f"WARNING: Missing voice_ids for: {', '.join(missing)}")
        print("Fill in voice_config.yaml before generating.\n")

    client = ElevenLabsClient(api_key=api_key)
    model_id = config.get("model") or DEFAULT_MODEL
    output_format = config.get("output_format") or DEFAULT_FORMAT
    silence_cfg = config["silence_ms"]
    gap_ms = silence_cfg.get("between_segments", 200)
    section_ms = silence_cfg.get("section_break", 500)
    episode_ms = silence_cfg.get("episode_break", 1000)

    # Group by episode
    ep_groups = {}
    for s in segs:
        ep_groups.setdefault(s.episode, []).append(s)

    output_dir = project_path / "audio" / "episodes"
    output_dir.mkdir(parents=True, exist_ok=True)

    all_audio = []
    total_chars_generated = 0

    for ep_num in sorted(ep_groups.keys()):
        ep_segs = ep_groups[ep_num]
        out_file = output_dir / f"ep_{ep_num:03d}.mp3"

        if out_file.exists() and not force:
            print(f"  Ep {ep_num:3d}: exists (use --force to regenerate)")
            if concat:
                # Reuse the file on disk so the combined output is complete
                all_audio.append(AudioSegment.from_mp3(str(out_file)))
                all_audio.append(AudioSegment.silent(duration=episode_ms))
            continue

        print(f"  Ep {ep_num:3d}: generating ({len(ep_segs)} segments)...",
              end="", flush=True)

        ep_audio = AudioSegment.empty()
        ep_chars = 0
        prev_section = None

        for seg in ep_segs:
            voice_id, voice_settings = resolve_voice(seg.character, config)
            if not voice_id:
                continue

            # Silence gap: longer pause on a section change, short one otherwise
            if prev_section is not None and seg.section != prev_section:
                ep_audio += AudioSegment.silent(duration=section_ms)
            elif len(ep_audio) > 0:
                ep_audio += AudioSegment.silent(duration=gap_ms)
            prev_section = seg.section

            # TTS call
            try:
                audio_iter = client.text_to_speech.convert(
                    voice_id=voice_id,
                    text=seg.text,
                    model_id=model_id,
                    output_format=output_format,
                    voice_settings=VoiceSettings(
                        stability=voice_settings["stability"],
                        similarity_boost=voice_settings["similarity_boost"],
                    ),
                )
                audio_bytes = b"".join(audio_iter)
                seg_audio = AudioSegment.from_mp3(io.BytesIO(audio_bytes))
                ep_audio += seg_audio
                # Only successfully synthesised text counts toward usage
                ep_chars += len(seg.text)
            except Exception as e:
                # Best effort: report the failed line and keep going
                print(f"\n    ERROR [{seg.character}]: {e}")

        # Export episode (pad end so last word doesn't clip)
        if len(ep_audio) > 0:
            ep_audio += AudioSegment.silent(duration=500)
            ep_audio.export(str(out_file), format="mp3")
            duration = len(ep_audio) / 1000
            print(f" {duration:.0f}s, {ep_chars:,} chars")

            # Log credits
            log_credits(project_path, ep_num, "ALL", ep_chars)
            log_to_cost_tracker(project_path, ep_num, ep_chars, len(ep_segs))
            total_chars_generated += ep_chars

            if concat:
                all_audio.append(ep_audio)
                all_audio.append(AudioSegment.silent(duration=episode_ms))
        else:
            print(" (empty — no voice_ids configured?)")

    # Concat mode
    if concat and all_audio:
        combined = all_audio[0]
        for a in all_audio[1:]:
            combined += a
        concat_file = project_path / "audio" / f"{project_path.name}_complete.mp3"
        combined.export(str(concat_file), format="mp3")
        print(f"\nConcatenated: {concat_file} ({len(combined) / 1000:.0f}s)")

    print(f"\nTotal: {total_chars_generated:,} chars generated")


def _generate_qwen3(project_path, config, episodes, dialogue_only, concat, force):
    """Generate MP3 audio files via Qwen3-TTS (local).

    One ``audio/episodes/ep_NNN.mp3`` is written per episode. Episodes whose
    file already exists are skipped unless *force* is set. With *concat*,
    all episodes (freshly generated or already on disk) are also joined into
    ``audio/<project>_complete.mp3``.

    Segments whose speaker resolves to an empty name are skipped. A segment
    whose synthesis fails is reported and skipped; the episode continues.
    The model named by ``qwen3_model`` in the config is loaded once per run.

    Exits with status 1 when soundfile, numpy or qwen_tts cannot be imported.
    """
    try:
        import soundfile as sf
        import numpy as np
        from qwen_tts import Qwen3TTSModel
    except ImportError as e:
        print(f"ERROR: Missing dependency for Qwen3: {e}", file=sys.stderr)
        print("Install: pip install qwen-tts soundfile", file=sys.stderr)
        sys.exit(1)

    # YAML sections that are present but empty load as None (for example a
    # bare "characters:" line). Work on a normalised copy so neither the
    # checks below nor resolve_voice() trip over None.
    config = dict(config)
    config["narrator"] = config.get("narrator") or {}
    config["characters"] = {
        name: (entry or {})
        for name, entry in (config.get("characters") or {}).items()
    }
    config["side_voice_pool"] = config.get("side_voice_pool") or []
    config["silence_ms"] = config.get("silence_ms") or {}

    fountain = find_fountain_file(project_path, config)
    all_segs = parse_fountain(fountain)
    segs = filter_segments(all_segs, dialogue_only=dialogue_only, episodes=episodes)

    if not segs:
        print("No segments to generate.")
        return

    # Check for empty speakers (warning only; generation still proceeds)
    missing = []
    nr_spk = config["narrator"].get("speaker", "")
    if not nr_spk and not dialogue_only:
        missing.append("narrator")
    for name, ch in config["characters"].items():
        if not ch.get("speaker", ""):
            missing.append(name)
    pool = config["side_voice_pool"]
    # True for an empty pool as well as a pool of blank entries
    pool_empty = not any(
        (e.get("speaker", "") if isinstance(e, dict) else e) for e in pool)
    if pool_empty:
        missing.append("side_voice_pool (all empty)")

    if missing:
        print(f"WARNING: Missing speakers for: {', '.join(missing)}")
        print("Fill in voice_config.yaml before generating.\n")

    # Load model once
    model_name = config.get("qwen3_model") or "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"
    print(f"Loading Qwen3 model: {model_name}...")
    model = Qwen3TTSModel.from_pretrained(model_name)
    print("Model loaded.")

    silence_cfg = config["silence_ms"]
    gap_ms = silence_cfg.get("between_segments", 200)
    section_ms = silence_cfg.get("section_break", 500)
    episode_ms = silence_cfg.get("episode_break", 1000)

    # Group by episode
    ep_groups = {}
    for s in segs:
        ep_groups.setdefault(s.episode, []).append(s)

    output_dir = project_path / "audio" / "episodes"
    output_dir.mkdir(parents=True, exist_ok=True)

    all_audio = []
    total_chars_generated = 0

    for ep_num in sorted(ep_groups.keys()):
        ep_segs = ep_groups[ep_num]
        out_file = output_dir / f"ep_{ep_num:03d}.mp3"

        if out_file.exists() and not force:
            print(f"  Ep {ep_num:3d}: exists (use --force to regenerate)")
            if concat:
                # Reuse the file on disk so the combined output is complete
                all_audio.append(AudioSegment.from_mp3(str(out_file)))
                all_audio.append(AudioSegment.silent(duration=episode_ms))
            continue

        print(f"  Ep {ep_num:3d}: generating ({len(ep_segs)} segments)...",
              end="", flush=True)

        wav_clips = []
        sample_rate = None   # learned from the first successful synthesis
        ep_chars = 0
        prev_section = None

        for seg in ep_segs:
            speaker, voice_settings = resolve_voice(seg.character, config)
            if not speaker:
                continue

            # Silence gap (as samples at the detected sample rate); skipped
            # until the rate is known, i.e. before the first clip exists
            if sample_rate is not None:
                if prev_section is not None and seg.section != prev_section:
                    wav_clips.append(np.zeros(int(sample_rate * section_ms / 1000)))
                elif wav_clips:
                    wav_clips.append(np.zeros(int(sample_rate * gap_ms / 1000)))
            prev_section = seg.section

            # TTS call
            try:
                instruct = voice_settings.get("instruct", "")
                wavs, sr = model.generate_custom_voice(
                    text=seg.text,
                    speaker=speaker,
                    language="English",
                    instruct=instruct if instruct else None,
                )
                if sample_rate is None:
                    sample_rate = sr
                wav_clips.append(wavs[0])
                # Only successfully synthesised text counts toward usage
                ep_chars += len(seg.text)
            except Exception as e:
                # Best effort: report the failed line and keep going
                print(f"\n    ERROR [{seg.character}]: {e}")

        # Export episode
        if wav_clips and sample_rate:
            # Add tail silence
            wav_clips.append(np.zeros(int(sample_rate * 0.5)))
            combined_wav = np.concatenate(wav_clips)

            # Write WAV to buffer, convert to MP3 via pydub
            wav_buf = io.BytesIO()
            sf.write(wav_buf, combined_wav, sample_rate, format="WAV")
            wav_buf.seek(0)
            ep_audio = AudioSegment.from_wav(wav_buf)
            ep_audio.export(str(out_file), format="mp3")

            duration = len(ep_audio) / 1000
            print(f" {duration:.0f}s, {ep_chars:,} chars")

            # Log (no credit cost for local model)
            log_credits(project_path, ep_num, "ALL", ep_chars)
            log_to_cost_tracker(project_path, ep_num, ep_chars, len(ep_segs),
                                provider="qwen3")
            total_chars_generated += ep_chars

            if concat:
                all_audio.append(ep_audio)
                all_audio.append(AudioSegment.silent(duration=episode_ms))
        else:
            print(" (empty — no speakers configured?)")

    # Concat mode
    if concat and all_audio:
        combined = all_audio[0]
        for a in all_audio[1:]:
            combined += a
        concat_file = project_path / "audio" / f"{project_path.name}_complete.mp3"
        combined.export(str(concat_file), format="mp3")
        print(f"\nConcatenated: {concat_file} ({len(combined) / 1000:.0f}s)")

    print(f"\nTotal: {total_chars_generated:,} chars generated")


def _generate_qwen3_clone(project_path, config, episodes, dialogue_only, concat, force):
    """Generate per-episode MP3s via Qwen3-TTS ICL (in-context learning) voice cloning.

    Uses the Base 1.7B model in ICL mode — the model conditions on the full
    reference audio + transcript (Rainbow Passage), preserving identity far
    better than x-vector-only mode.

    NOTE(review): the voice direction for each line (the segment's own
    parenthetical, else the character's ``instruct`` fallback) is resolved and
    printed, but only ``seg.text`` is passed to the model — the direction is
    not part of the synthesis input. Confirm whether that is intended.

    voice_config.yaml schema for qwen3_clone:
        engine: qwen3_clone
        voice_refs: "/abs/path/to/voice_refs"  # dir containing ref_*.wav files
        narrator:
          ref_wav: ref_narrator.wav     # relative to voice_refs dir
        characters:
          JINX:
            ref_wav: ref_jinx.wav
            instruct: "..."             # fallback parenthetical if none in script

    Args:
        project_path: Project folder (Path). Episodes are written to
            ``<project>/audio/episodes/ep_NNN.mp3``.
        config: Parsed voice_config.yaml mapping.
        episodes: Episode filter, forwarded to ``filter_segments``.
        dialogue_only: Forwarded to ``filter_segments``.
        concat: When true, also export ``<project>/audio/<name>_complete.mp3``
            built from every episode (existing or freshly generated).
        force: Regenerate episodes whose MP3 already exists.

    Exits with status 1 if the Qwen3 dependencies are missing or the
    voice_refs directory does not exist.
    """
    try:
        import numpy as np
        import torch
        import soundfile as sf
        from qwen_tts import Qwen3TTSModel, VoiceClonePromptItem
    except ImportError as e:
        print(f"ERROR: Missing dependency: {e}", file=sys.stderr)
        print("Install: python3.11 -m pip install qwen-tts soundfile", file=sys.stderr)
        sys.exit(1)

    BASE_MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"

    # Rainbow Passage — the transcript of all voice reference WAVs.
    # Required for ICL mode so the model can align ref audio with ref text.
    RAINBOW_PASSAGE = (
        "When the sunlight strikes raindrops in the air, they act like a prism and form a rainbow. "
        "The rainbow is a division of white light into many beautiful colors. "
        "These take the shape of a long round arch, with its path high above, "
        "and its two ends apparently beyond the horizon. "
        "There is, according to legend, a boiling pot of gold at one end. "
        "People look but no one ever finds it. "
        "When a man looks for something beyond his reach, "
        "his friends say he is looking for the pot of gold at the end of the rainbow."
    )

    # Resolve voice_refs directory
    refs_dir_raw = config.get("voice_refs", "")
    if refs_dir_raw:
        refs_dir = Path(refs_dir_raw)
    else:
        # Default: a "voice_refs" directory that is a sibling of the project folder.
        refs_dir = project_path.parent / "voice_refs"
    if not refs_dir.exists():
        print(f"ERROR: voice_refs dir not found: {refs_dir}", file=sys.stderr)
        print("Run generate_voice_refs.py first.", file=sys.stderr)
        sys.exit(1)

    fountain = find_fountain_file(project_path, config)
    all_segs = parse_fountain(fountain)
    segs = filter_segments(all_segs, dialogue_only=dialogue_only, episodes=episodes)

    if not segs:
        print("No segments to generate.")
        return

    def _char_config(char):
        """Return ``(ref_wav_path_or_None, default_instruct)`` for a character."""
        if char == "NARRATOR":
            entry = config.get("narrator")
        else:
            characters = config.get("characters") or {}
            if char in characters:
                entry = characters[char]
            else:
                # Side voice pool — pick deterministically by name hash so an
                # unlisted character keeps the same voice across runs.
                pool = config.get("side_voice_pool") or []
                if pool:
                    idx = int(hashlib.md5(char.encode()).hexdigest(), 16) % len(pool)
                    entry = pool[idx]
                else:
                    entry = {}
        # A null or non-mapping entry means "nothing configured", in every branch.
        if not isinstance(entry, dict):
            entry = {}
        ref_wav_name = entry.get("ref_wav", "")
        ref_wav = refs_dir / ref_wav_name if ref_wav_name else None
        instruct = entry.get("instruct", "")
        return ref_wav, instruct

    def _to_numpy(value):
        """Convert a tensor (anything with ``.cpu()``) or array-like to a NumPy array."""
        return value.cpu().numpy() if hasattr(value, "cpu") else np.array(value)

    # Load Base model (ICL mode for voice cloning)
    print(f"Loading Base model: {BASE_MODEL}...")
    base_model = Qwen3TTSModel.from_pretrained(BASE_MODEL)
    print("Model loaded.\n")

    # Pre-build ICL voice clone prompts for all characters appearing in segs.
    # ICL mode passes the full reference audio context (speech codes + speaker
    # embedding + transcript) to the model, preserving identity far better
    # than x-vector-only mode.
    # Cache as .icl.npz files next to the ref WAVs to avoid re-extracting.
    chars_needed = {s.character for s in segs}
    identity = {}
    print("Loading ICL identity prompts...")
    for char in sorted(chars_needed):
        ref_wav, _ = _char_config(char)
        if not (ref_wav and ref_wav.exists()):
            print(f"  {char}: WARNING — ref_wav not found ({ref_wav}), skipping")
            continue

        cache_path = ref_wav.with_suffix(".icl.npz")
        prompt = None

        # The cache is only trusted when it is at least as new as the ref WAV.
        if cache_path.exists() and cache_path.stat().st_mtime >= ref_wav.stat().st_mtime:
            try:
                # The cache holds plain numeric arrays, so pickle support is not
                # needed; the context manager closes the underlying zip file.
                with np.load(str(cache_path), allow_pickle=False) as cached:
                    embedding = torch.tensor(cached["embedding"])
                    ref_code = torch.tensor(cached["ref_code"])
            except Exception as e:
                # The cache is best-effort: any unreadable/corrupt file just
                # falls through to a fresh extraction below.
                print(f"  {char}: cache {cache_path.name} unreadable ({e}), re-extracting")
            else:
                print(f"  {char}: {ref_wav.name} (cached ICL)", flush=True)
                prompt = [VoiceClonePromptItem(
                    ref_code=ref_code,
                    ref_spk_embedding=embedding,
                    x_vector_only_mode=False,
                    icl_mode=True,
                    ref_text=RAINBOW_PASSAGE,
                )]

        if prompt is None:
            print(f"  {char}: {ref_wav.name} (extracting ICL)...", end="", flush=True)
            prompt = base_model.create_voice_clone_prompt(
                ref_audio=str(ref_wav),
                ref_text=RAINBOW_PASSAGE,
                x_vector_only_mode=False,
            )
            # Cache to disk
            item = prompt[0]
            emb_np = _to_numpy(item.ref_spk_embedding)
            rc_np = _to_numpy(item.ref_code)
            try:
                np.savez(str(cache_path), embedding=emb_np, ref_code=rc_np)
            except OSError as e:
                # A read-only refs dir must not prevent generation.
                print(f" done (cache write failed: {e})")
            else:
                print(f" done (cached → {cache_path.name})")

        identity[char] = prompt

    silence_cfg = config.get("silence_ms") or {}
    gap_ms     = silence_cfg.get("between_segments", 200)
    section_ms = silence_cfg.get("section_break", 500)
    episode_ms = silence_cfg.get("episode_break", 1000)

    ep_groups = {}
    for s in segs:
        ep_groups.setdefault(s.episode, []).append(s)

    output_dir = project_path / "audio" / "episodes"
    output_dir.mkdir(parents=True, exist_ok=True)

    all_audio = []
    total_chars = 0

    for ep_num in sorted(ep_groups.keys()):
        ep_segs = ep_groups[ep_num]
        out_file = output_dir / f"ep_{ep_num:03d}.mp3"

        if out_file.exists() and not force:
            print(f"  Ep {ep_num:3d}: exists (use --force to regenerate)")
            if concat:
                all_audio.append(AudioSegment.from_mp3(str(out_file)))
                all_audio.append(AudioSegment.silent(duration=episode_ms))
            continue

        print(f"\n  Ep {ep_num:3d}: {len(ep_segs)} segments")

        wav_clips = []
        sample_rate = None      # taken from the first successfully generated clip
        ep_chars = 0
        prev_section = None     # section of the last clip actually appended

        for seg in ep_segs:
            # Characters without a usable reference voice were warned about above.
            if seg.character not in identity:
                continue

            _, default_instruct = _char_config(seg.character)
            instruct = seg.instruct or default_instruct

            label = f"[{seg.character:8s}]"
            print(f"    {label} {seg.text[:55]}...")
            if instruct:
                print(f"               (direction: {instruct[:65]})")

            try:
                # ICL voice clone — identity locked via ref_code + ref_text
                final_wavs, final_sr = base_model.generate_voice_clone(
                    text=seg.text,
                    language="English",
                    voice_clone_prompt=identity[seg.character],
                )
                clip = final_wavs[0]
                clip_seconds = len(clip) / final_sr
            except Exception as e:
                # Best-effort: one failed line must not abort the episode.
                print(f"               ERROR: {e}")
                continue

            if sample_rate is None:
                sample_rate = final_sr
            else:
                # Insert the pause only once the clip exists, so a failed
                # segment does not leave an orphan gap in the episode.
                section_changed = prev_section and seg.section != prev_section
                pause_ms = section_ms if section_changed else gap_ms
                wav_clips.append(np.zeros(int(sample_rate * pause_ms / 1000)))
            prev_section = seg.section

            wav_clips.append(clip)
            ep_chars += len(seg.text)
            print(f"               → {clip_seconds:.1f}s")

        # Export episode
        if wav_clips and sample_rate:
            # Half a second of tail silence
            wav_clips.append(np.zeros(int(sample_rate * 0.5)))
            combined_wav = np.concatenate(wav_clips)
            # Write WAV to an in-memory buffer, then convert to MP3 via pydub
            wav_buf = io.BytesIO()
            sf.write(wav_buf, combined_wav, sample_rate, format="WAV")
            wav_buf.seek(0)
            ep_audio = AudioSegment.from_wav(wav_buf)
            ep_audio.export(str(out_file), format="mp3")
            duration = len(ep_audio) / 1000
            print(f"\n  Ep {ep_num:3d}: done — {duration:.0f}s, {ep_chars:,} chars → {out_file.name}")
            log_credits(project_path, ep_num, "ALL", ep_chars)
            log_to_cost_tracker(project_path, ep_num, ep_chars, len(ep_segs),
                                provider="qwen3_clone")
            total_chars += ep_chars
            if concat:
                all_audio.append(ep_audio)
                all_audio.append(AudioSegment.silent(duration=episode_ms))
        else:
            print(f"  Ep {ep_num:3d}: empty — no speakers configured?")

    if concat and all_audio:
        combined = all_audio[0]
        for a in all_audio[1:]:
            combined += a
        concat_file = project_path / "audio" / f"{project_path.name}_complete.mp3"
        combined.export(str(concat_file), format="mp3")
        print(f"\nConcatenated: {concat_file}")

    print(f"\nTotal: {total_chars:,} chars generated")


def cmd_budget(project_path):
    """Print this month's logged character usage against the monthly budget.

    Reads the project's credit log, keeps the entries whose timestamp starts
    with the current UTC ``YYYY-MM`` prefix, and prints used / budget /
    remaining characters, any overage cost, and a per-episode breakdown.

    Args:
        project_path: Project folder (Path); its name is used in the header.
    """
    entries = load_credit_log(project_path)
    now = datetime.now(timezone.utc)
    month_prefix = now.strftime("%Y-%m")

    # ``or ""`` tolerates entries whose timestamp is present but null.
    month_entries = [e for e in entries
                     if str(e.get("timestamp") or "").startswith(month_prefix)]
    total = sum(e.get("chars", 0) for e in month_entries)
    remaining = MONTHLY_CHAR_BUDGET - total

    print(f"Budget: {project_path.name} ({month_prefix})")
    print(f"  Used:      {total:>8,} chars")
    print(f"  Budget:    {MONTHLY_CHAR_BUDGET:>8,} chars")
    print(f"  Remaining: {remaining:>8,} chars ({max(0, remaining) / MONTHLY_CHAR_BUDGET * 100:.0f}%)")
    if total > MONTHLY_CHAR_BUDGET:
        over = total - MONTHLY_CHAR_BUDGET
        print(f"  OVER BUDGET by {over:,} chars (${over / 1000 * RATE_PER_1K_CHARS:.2f} overage)")

    if not month_entries:
        return

    def _episode_sort_key(item):
        # Numeric episodes first (in numeric order), everything else after,
        # ordered by its string form. Sorting the raw keys raises TypeError
        # when the "?" placeholder sits next to integer episode numbers.
        ep = item[0]
        if isinstance(ep, (int, float)):
            return (0, ep, "")
        return (1, 0, str(ep))

    print("\n  Episodes this month:")
    by_ep = {}
    for e in month_entries:
        ep = e.get("episode", "?")
        by_ep[ep] = by_ep.get(ep, 0) + e.get("chars", 0)
    for ep, chars in sorted(by_ep.items(), key=_episode_sort_key):
        print(f"    Ep {ep}: {chars:,} chars")


def cmd_voice_direct(project_path, episodes):
    """Analyze dialogue parenthetical coverage for voice direction.

    Reports which dialogue lines have/lack emotional parentheticals
    in the target episode range. Used by the /listen skill to know
    which lines need voice direction before TTS generation.

    Every dialogue segment lands in exactly one bucket:
      * needs new          — no parenthetical at all
      * already rich       — parenthetical of 20+ words containing an em-dash
      * needs elaboration  — has a parenthetical that is not "rich"

    Args:
        project_path: Project folder (Path).
        episodes: Episode filter forwarded to ``filter_segments``.
    """
    config = load_config(project_path) if config_path(project_path).exists() else {}
    fountain = find_fountain_file(project_path, config or None)
    print(f"Parsing: {fountain.name}")

    all_segs = parse_fountain(fountain)
    segs = filter_segments(all_segs, episodes=episodes)
    dialogue = [s for s in segs if s.type == "dialogue"]

    if not dialogue:
        print("No dialogue segments found for the given filters.")
        return

    # Single pass, three buckets. The previous ``s not in rich`` list test
    # rescanned the rich list for every directed line (quadratic).
    rich, needs_elab, needs_dir = [], [], []
    for s in dialogue:
        if not s.instruct:
            needs_dir.append(s)
        elif len(s.instruct.split()) >= 20 and "\u2014" in s.instruct:
            # "Rich" = 20+ words with em-dash (likely already well-directed)
            rich.append(s)
        else:
            needs_elab.append(s)

    ep_set = sorted(set(s.episode for s in dialogue))
    print(f"\nEpisodes: {ep_set[0]}-{ep_set[-1]} ({len(ep_set)} episodes)")
    print(f"Dialogue lines: {len(dialogue)}")
    print(f"  Already rich:       {len(rich):3d} (skip)")
    print(f"  Needs elaboration:  {len(needs_elab):3d} (has parenthetical, enhance it)")
    print(f"  Needs new:          {len(needs_dir):3d} (no parenthetical)")
    print(f"  Total to process:   {len(needs_elab) + len(needs_dir):3d}")

    if needs_dir:
        print("\nLines needing NEW parentheticals:")
        for s in needs_dir:
            print(f"  Ep {s.episode:3d} [{s.character:12s}] {s.text[:60]}...")

    if needs_elab:
        print("\nLines needing ELABORATION:")
        for s in needs_elab:
            print(f"  Ep {s.episode:3d} [{s.character:12s}] ({s.instruct[:40]}...) {s.text[:40]}...")

    # Trailing summary in a key: value layout.
    print("\n---")
    print("voice_direct_summary:")
    print(f"  fountain: {fountain}")
    print(f"  total_dialogue: {len(dialogue)}")
    print(f"  needs_new: {len(needs_dir)}")
    print(f"  needs_elaboration: {len(needs_elab)}")
    print(f"  already_rich: {len(rich)}")


def cmd_voices():
    """Print each ElevenLabs voice: its name, voice ID, and any label tags.

    Exits with status 1 when the ``elevenlabs`` package is not installed or
    ``ELEVEN_API_KEY`` is unset.
    """
    if ElevenLabsClient is None:
        print("ERROR: pip install elevenlabs", file=sys.stderr)
        sys.exit(1)

    api_key = os.environ.get("ELEVEN_API_KEY")
    if not api_key:
        print("ERROR: Set ELEVEN_API_KEY environment variable.", file=sys.stderr)
        sys.exit(1)

    # Labels shown on the tag line, in display order.
    label_keys = ("gender", "age", "accent", "description", "use_case")

    listing = ElevenLabsClient(api_key=api_key).voices.get_all()

    print("Available Voices:\n")
    for entry in listing.voices:
        meta = entry.labels or {}
        # Keep only the labels that are actually set.
        present = [value for value in (meta.get(key, "") for key in label_keys) if value]
        tag_line = " | ".join(present)

        print(f"  {entry.name}")
        print(f"    ID: {entry.voice_id}")
        if tag_line:
            print(f"    {tag_line}")
        print()


# ── CLI ──

def parse_episode_filter(args):
    """Parse --episode and --episodes flags into a set of episode numbers.

    ``--episode N`` wins over ``--episodes`` when both are given. ``--episodes``
    accepts a single number (``"7"``) or an inclusive range (``"2-4"``); a
    reversed range (``"4-2"``) is treated the same as ``"2-4"``.

    Args:
        args: Parsed argparse namespace; either attribute may be absent.

    Returns:
        A set of ints, or ``None`` when no filter was requested.

    Raises:
        ValueError: If ``--episodes`` is not of the form ``N`` or ``N-M``.
    """
    # ``is not None`` (not truthiness) so that episode 0 is a valid filter.
    episode = getattr(args, "episode", None)
    if episode is not None:
        return {episode}

    spec = getattr(args, "episodes", None)
    if not spec:
        return None

    parts = spec.split("-")
    if len(parts) > 2:
        raise ValueError(f"Invalid --episodes value {spec!r}: expected N or N-M")
    try:
        bounds = [int(part) for part in parts]
    except ValueError:
        raise ValueError(
            f"Invalid --episodes value {spec!r}: expected N or N-M") from None

    # min/max makes a reversed range select the same episodes instead of none.
    return set(range(min(bounds), max(bounds) + 1))


def main():
    """CLI entry point: parse arguments and dispatch to the chosen command.

    Usage errors (a malformed ``--episodes`` value) are reported through
    argparse, which prints the usage line and exits with status 2. A project
    path that is not a directory exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Fountain Screenplay Audio Reader — ElevenLabs TTS")
    parser.add_argument("project", help="Path to project folder")

    sub = parser.add_subparsers(dest="command", required=True)

    def _add_episode_flags(command_parser):
        # Episode selection flags shared by dry-run, generate and voice-direct.
        command_parser.add_argument("--episode", type=int)
        command_parser.add_argument("--episodes", help="Range: N-M")

    sub.add_parser("init", help="Create voice_config.yaml template")

    dr = sub.add_parser("dry-run", help="Char counts, no API calls")
    _add_episode_flags(dr)
    dr.add_argument("--dialogue-only", action="store_true")

    gen = sub.add_parser("generate", help="Generate MP3 audio files")
    _add_episode_flags(gen)
    gen.add_argument("--dialogue-only", action="store_true")
    gen.add_argument("--concat", action="store_true")
    gen.add_argument("--force", action="store_true")

    vd = sub.add_parser("voice-direct", help="Analyze parenthetical coverage")
    _add_episode_flags(vd)

    sub.add_parser("budget", help="Monthly credit usage")
    sub.add_parser("voices", help="List ElevenLabs voices")

    args = parser.parse_args()
    project = Path(args.project).resolve()

    if not project.is_dir():
        print(f"ERROR: Not a directory: {project}", file=sys.stderr)
        sys.exit(1)

    # Resolve the episode filter only for the commands that accept it, and
    # turn a malformed value into a normal usage error instead of a traceback.
    episodes = None
    if args.command in ("voice-direct", "dry-run", "generate"):
        try:
            episodes = parse_episode_filter(args)
        except ValueError as e:
            parser.error(f"invalid episode filter: {e}")

    if args.command == "init":
        cmd_init(project)
    elif args.command == "voice-direct":
        cmd_voice_direct(project, episodes)
    elif args.command == "dry-run":
        cmd_dry_run(project, episodes, args.dialogue_only)
    elif args.command == "generate":
        cmd_generate(project, episodes,
                     args.dialogue_only, args.concat, args.force)
    elif args.command == "budget":
        cmd_budget(project)
    elif args.command == "voices":
        cmd_voices()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
