#!/usr/bin/env python3
"""Quick test: CustomVoice model with instruct for line-by-line direction.

Tests whether Aiden's voice stays consistent while instruct changes delivery.
"""
import numpy as np
import soundfile as sf
from qwen_tts import Qwen3TTSModel

# Identifier of the CustomVoice checkpoint handed to the model loader.
CV_MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"

# Jinx's lines from episode 1 as (spoken text, delivery direction) pairs.
# A direction of None marks a line that carries no instruct at all.
_SCRIPT = (
    (
        "Panel's corroded. Sixty-forty the wire's live. Seventy-thirty I can strip it before the arc hits my face.",
        "bored, clinical, rattling off death odds the way someone reads a grocery list",
    ),
    (
        "Daddy needs a new pair of lungs.",
        "low, conspiratorial warmth, talking to herself like a gambler whispering to dice",
    ),
    (
        "Oh great. It's alive.",
        "flat deadpan, barely above a whisper, the sigh of someone watching their luck walk out the door",
    ),
    (
        "Can't really talk with the hand situation, Chrome Boy.",
        "choking, words squeezed out between compressed breaths, physically fighting for air but still cracking wise",
    ),
    # Same text as the previous entry with no direction: the baseline take.
    (
        "Can't really talk with the hand situation, Chrome Boy.",
        None,
    ),
)

# Each entry is a dict with a "text" key and an "instruct" key, in that order.
LINES = [
    {"text": spoken, "instruct": direction}
    for spoken, direction in _SCRIPT
]

# Load the CustomVoice model once; every line below is synthesized with it.
print(f"Loading CustomVoice model: {CV_MODEL}...")
model = Qwen3TTSModel.from_pretrained(CV_MODEL)
print("Model loaded.\n")

# Seconds of silence inserted between consecutive lines in the export.
GAP_SECONDS = 1.0

all_wavs = []
sample_rate = None

for i, line in enumerate(LINES):
    instruct = line["instruct"]
    tag = f"[{i+1}] {'instruct' if instruct else 'NO instruct'}"
    print(f"{tag}: {line['text'][:60]}...")
    if instruct:
        print(f"     direction: {instruct[:70]}")

    kwargs = dict(
        text=line["text"],
        language="English",
        speaker="Aiden",
    )
    # The baseline entry has no direction, so the keyword is left out entirely
    # rather than passed as None.
    if instruct:
        kwargs["instruct"] = instruct

    wavs, sr = model.generate_custom_voice(**kwargs)
    if sample_rate is None:
        sample_rate = sr
    elif sr != sample_rate:
        # The combined file is written at a single rate; mixing rates would
        # silently play later lines at the wrong speed.
        raise ValueError(
            f"Sample rate changed from {sample_rate} to {sr} on line {i + 1}"
        )

    # Only the first returned waveform is used.
    # NOTE(review): assumes wavs[0] is a 1-D array-like of samples — confirm
    # against the qwen_tts return contract.
    wav = np.asarray(wavs[0])
    duration = len(wav) / sr
    print(f"     -> {duration:.1f}s\n")

    # Put the gap *between* lines only, so the export does not end with
    # trailing silence. The silence takes the waveform's dtype because
    # np.zeros defaults to float64, which would otherwise promote the whole
    # concatenated buffer.
    if all_wavs:
        all_wavs.append(np.zeros(int(sr * GAP_SECONDS), dtype=wav.dtype))
    all_wavs.append(wav)

# Concatenate and export
combined = np.concatenate(all_wavs)
out_path = "/Users/joeturnerlin/Dropbox/CLAUDE_PROJECTS/projects/tartarus/audio/test_custom_voice_aiden.wav"
sf.write(out_path, combined, sample_rate)
print(f"Output: {out_path}")
print(f"Total: {len(combined)/sample_rate:.1f}s")
