#!/usr/bin/env python3
"""Seedance Chinese Translation A/B Test Harness.

Tests whether Mandarin prompts produce better results than English for
Seedance 2.0, particularly for fabric, spatial, and architectural detail.
ByteDance's training data is Mandarin-dense — this tests whether that
gives Chinese prompts an edge.

Three phases:
  1. translate  — Translate English prompts to Mandarin via Gemini Flash ($0.01)
  2. compare    — Print side-by-side English/Chinese for review (free)
  3. generate   — Generate via both, produce HTML comparison ($$$)

Usage:
    # Phase 1: Translate prompts (costs ~$0.01)
    python3 tools/seedance_zh_ab_test.py --project tartarus --phase translate

    # Phase 2: Review translations side-by-side (free)
    python3 tools/seedance_zh_ab_test.py --project tartarus --phase compare

    # Phase 3: Generate and compare (costs ~$61 for 20 shots x 2 variants)
    python3 tools/seedance_zh_ab_test.py --project tartarus --phase generate

    # Custom shot selection
    python3 tools/seedance_zh_ab_test.py --project tartarus \\
        --shots EP01_SH003,EP01_SH008 --phase translate
"""

import argparse
import json
import os
import random
import re
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from recoil.pipeline.core.cost import read_cost_from_result  # noqa: E402

from recoil.core.paths import ProjectPaths as CoreProjectPaths  # noqa: E402

# ──────────────────────────────────────────────────────────────────────
# Translation via Gemini Flash
# ──────────────────────────────────────────────────────────────────────

# System instruction handed to the Gemini model in translate_prompt(). It is
# sent verbatim with every request, so any wording change here alters the
# translations the harness produces.
_TRANSLATE_SYSTEM = (
    "You are a professional Chinese translator specializing in "
    "cinematography and visual production terminology.\n\n"
    "Translate the following English video generation prompt into "
    "Mandarin Chinese.\n\n"
    "RULES:\n"
    "1. Preserve ALL technical cinematography terms (shot types, camera "
    "movements, lens specifications) — translate them to standard Chinese "
    "film terminology.\n"
    "2. Preserve ALL @ImageN, @VideoN, @AudioN reference tags exactly as-is.\n"
    "3. Preserve the quality suffix exactly: '4K, Ultra HD' etc. — these are "
    "universal tokens, do not translate.\n"
    "4. Focus on PRECISION for fabric descriptions, spatial relationships, "
    "and architectural details — these are where Chinese excels.\n"
    "5. Maintain the same sentence count and prose structure.\n"
    "6. Do NOT add or remove content — pure translation only.\n"
    "7. Output ONLY the translated prompt, nothing else."
)


def translate_prompt(english_prompt: str) -> str:
    """Translate an English Seedance prompt to Mandarin via Gemini Flash.

    Args:
        english_prompt: The English prompt text to translate.

    Returns:
        The translated prompt with surrounding whitespace stripped.

    Raises:
        RuntimeError: If the google-generativeai package is missing, no API
            key is configured, the response carries no readable text, or the
            translation is empty after stripping.

    Errors raised by the API call itself (network, quota, ...) are not
    wrapped and propagate unchanged.
    """
    try:
        import google.generativeai as genai
    except ImportError as err:
        raise RuntimeError(
            "google-generativeai not installed: pip install google-generativeai"
        ) from err

    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY or GOOGLE_API_KEY env var required")

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(
        "gemini-2.5-flash",
        system_instruction=_TRANSLATE_SYSTEM,
    )

    response = model.generate_content(english_prompt)

    # The `.text` accessor raises ValueError when the response holds no valid
    # text part (for example a blocked candidate); surface that as the
    # RuntimeError this function documents.
    try:
        raw_text = response.text
    except ValueError as err:
        raise RuntimeError(
            f"Gemini Flash returned no usable translation: {err}"
        ) from err

    # Strip before testing so a whitespace-only reply counts as empty.
    translation = (raw_text or "").strip()
    if not translation:
        raise RuntimeError("Gemini Flash returned empty translation")

    return translation


# ──────────────────────────────────────────────────────────────────────
# Shot selection — prefer fabric, spatial, architectural shots
# ──────────────────────────────────────────────────────────────────────

# Shot-type codes that earn the +3 bonus in _score_shot_for_zh_test().
_ZH_PRIORITY_SHOT_TYPES = {"EWS", "WS", "MLS", "MFS"}
# Material and location words; each keyword found in a scored line adds +2.
_ZH_PRIORITY_KEYWORDS = {
    "fabric", "texture", "silk", "concrete", "marble",
    "wood", "stone", "brick", "glass", "steel",
    "corridor", "alley", "market", "temple", "palace",
}


def _score_shot_for_zh_test(shot: dict) -> int:
    """Score a shot for Chinese translation test relevance.

    Scoring:
        * +3 when ``prompt_data["shot_type"]`` is in _ZH_PRIORITY_SHOT_TYPES
          (a missing shot type is treated as "MS").
        * +2 for every keyword from _ZH_PRIORITY_KEYWORDS that occurs as a
          case-insensitive substring of the skeleton's ``environment_line``
          or ``subject_line``; a keyword present in both lines counts twice.

    Missing, null, or non-string values are treated as empty instead of
    raising, so a plan file with JSON ``null`` fields still scores.

    Args:
        shot: One shot dict as loaded from a plan file.

    Returns:
        The relevance score (0 or greater).
    """
    prompt_data = shot.get("prompt_data") or {}
    skeleton = prompt_data.get("prompt_skeleton") or {}
    shot_type = prompt_data.get("shot_type", "MS")

    score = 3 if shot_type in _ZH_PRIORITY_SHOT_TYPES else 0

    for field in ("environment_line", "subject_line"):
        value = skeleton.get(field)
        text = value.lower() if isinstance(value, str) else ""
        score += 2 * sum(1 for kw in _ZH_PRIORITY_KEYWORDS if kw in text)

    return score


def select_test_shots(plan_dir: Path, max_shots: int = 20) -> list[dict]:
    """Select the best shots for zh A/B testing from plan files.

    Every ``ep_*_plan.json`` file in ``plan_dir`` is read (in sorted filename
    order) and each entry of its ``"shots"`` list is annotated in place with:

        * ``_source_plan``: path of the plan file it came from
        * ``_zh_score``: result of _score_shot_for_zh_test()

    Args:
        plan_dir: Directory containing the episode plan JSON files.
        max_shots: Upper bound on the number of shots returned. Values below
            zero are treated as zero.

    Returns:
        Up to ``max_shots`` shot dicts, highest ``_zh_score`` first. Shots
        with equal scores keep their load order (the sort is stable).
    """
    shots = []
    for plan_file in sorted(plan_dir.glob("ep_*_plan.json")):
        # Read JSON as UTF-8 explicitly instead of relying on the locale default.
        with open(plan_file, encoding="utf-8") as f:
            plan = json.load(f)
        # `or []` also covers a plan whose "shots" value is JSON null.
        for shot in plan.get("shots") or []:
            shot["_source_plan"] = str(plan_file)
            shot["_zh_score"] = _score_shot_for_zh_test(shot)
            shots.append(shot)

    shots.sort(key=lambda s: s["_zh_score"], reverse=True)
    # Clamp so a negative limit cannot turn into a slice from the end.
    return shots[:max(0, max_shots)]


# ──────────────────────────────────────────────────────────────────────
# Phase handlers
# ──────────────────────────────────────────────────────────────────────

def phase_translate(shots: list[dict], output_dir: Path):
    """Phase 1: Translate prompts and save pairs.

    For each shot an English prompt is assembled from the skeleton's
    ``subject_line``, ``environment_line`` and ``action_line`` (each trimmed
    and terminated with a single period), translated with translate_prompt(),
    and recorded. Shots with no usable prompt text are skipped; shots whose
    translation raises are reported and left out.

    Args:
        shots: Shot dicts as loaded from plan files.
        output_dir: Directory that receives ``translation_pairs.json``;
            created if it does not exist. An existing file is overwritten.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    pairs_file = output_dir / "translation_pairs.json"

    pairs = []
    failed = 0
    for i, shot in enumerate(shots):
        shot_id = shot.get("shot_id", f"shot_{i}")
        prompt_data = shot.get("prompt_data") or {}
        skeleton = prompt_data.get("prompt_skeleton") or {}

        en_parts = []
        for field in ("subject_line", "environment_line", "action_line"):
            val = skeleton.get(field)
            # Normalise before testing so whitespace-only or "." lines do not
            # become a bare "." sentence that gets sent for translation.
            sentence = val.strip().rstrip(".") if isinstance(val, str) else ""
            if sentence:
                en_parts.append(sentence + ".")
        en_prompt = " ".join(en_parts)

        if not en_prompt:
            print(f"  [{shot_id}] SKIP — empty prompt data")
            continue

        print(
            f"  [{shot_id}] Translating ({len(en_prompt.split())} words)...",
            end=" ", flush=True,
        )
        try:
            zh_prompt = translate_prompt(en_prompt)
        except Exception as e:
            # Per-shot boundary: one failed translation must not abort the batch.
            failed += 1
            print(f"FAILED: {e}")
            continue

        print("OK")
        pairs.append({
            "shot_id": shot_id,
            "en_prompt": en_prompt,
            "zh_prompt": zh_prompt,
            "zh_score": shot.get("_zh_score", 0),
            "shot_type": prompt_data.get("shot_type", ""),
            "source_plan": shot.get("_source_plan", ""),
        })

    # ensure_ascii=False writes raw Chinese characters, so the file encoding
    # must be pinned to UTF-8 rather than left to the locale default.
    with open(pairs_file, "w", encoding="utf-8") as f:
        json.dump(pairs, f, indent=2, ensure_ascii=False)
    print(f"\n{len(pairs)} pairs saved to {pairs_file}")
    if failed:
        print(f"{failed} shot(s) failed to translate and were left out")


def phase_compare(output_dir: Path):
    """Phase 2: Print side-by-side comparison of translations.

    Reads ``translation_pairs.json`` from ``output_dir`` and prints, for each
    pair, the shot header followed by the English and Chinese prompts. Each
    prompt is cut to its first 200 characters; an ellipsis is appended only
    when text was actually cut off.

    Args:
        output_dir: Directory holding ``translation_pairs.json``.

    Exits with status 1 if the pairs file does not exist.
    """
    pairs_file = output_dir / "translation_pairs.json"
    if not pairs_file.exists():
        print(f"No translation pairs at {pairs_file}. Run --phase translate first.")
        sys.exit(1)

    # The pairs file holds raw Chinese text; read it as UTF-8 explicitly.
    with open(pairs_file, encoding="utf-8") as f:
        pairs = json.load(f)

    def preview(text: str, limit: int = 200) -> str:
        """Return text unchanged if short enough, else its head plus '...'."""
        return text if len(text) <= limit else text[:limit] + "..."

    print(f"\n{'='*70}")
    print(f"Chinese Translation Review — {len(pairs)} pairs")
    print(f"{'='*70}\n")

    for pair in pairs:
        print(f"[{pair['shot_id']}] (type: {pair['shot_type']}, "
              f"zh_score: {pair['zh_score']})")
        print(f"  EN: {preview(pair['en_prompt'])}")
        print(f"  ZH: {preview(pair['zh_prompt'])}")
        print()


def phase_generate(project: str, shots: list[dict], output_dir: Path):
    """Phase 3: Generate via both languages, produce HTML comparison.

    For every saved translation pair a ``video_i2v`` job is sent through
    ``dispatch()`` twice, once with the English prompt and once with the
    Chinese prompt. Each call carries a DispatchContext wrapping a StepRunner
    backed by an ExecutionStore for the project. (The original author's note
    says this path makes results appear in Dailies; that is not verifiable
    from this file.)

    Outputs written to ``output_dir``:
        * ``generations/<shot_id>/``: created per shot
        * ``generation_results.json``: one record per shot/language attempt
        * the blind comparison page from _write_comparison_html()

    Args:
        project: Project name passed to the store, paths and dispatch context.
        shots: Shot dicts from the plan files; used only to look up each
            shot's target duration.
        output_dir: Directory holding ``translation_pairs.json`` and
            receiving all outputs.

    Exits with status 1 if the pairs file does not exist.
    """
    pairs_file = output_dir / "translation_pairs.json"
    if not pairs_file.exists():
        print("No translation pairs found. Run --phase translate first.")
        sys.exit(1)

    # The pairs file holds raw Chinese text; read it as UTF-8 explicitly.
    with open(pairs_file, encoding="utf-8") as f:
        pairs = json.load(f)

    print(f"Generating {len(pairs)} shots x 2 variants (EN + ZH)")
    # Rough figure only: presumably 5 s per clip at $0.3034 per second —
    # TODO confirm. Actual durations below may differ from 5 s.
    print(f"Estimated cost: ${len(pairs) * 2 * 5 * 0.3034:.2f}")
    print()

    # Build a lookup from shot_id to shot data for duration extraction
    shot_lookup = {s.get("shot_id"): s for s in shots}

    from recoil.execution.execution_store import ExecutionStore
    from recoil.execution.step_runner import StepRunner
    from recoil.execution.step_types import ProjectPaths
    from recoil.pipeline.core.dispatch import dispatch
    from recoil.pipeline.core.dispatch_context import DispatchContext

    store = ExecutionStore(project)

    results = []
    for pair in pairs:
        shot_id = pair["shot_id"]
        (output_dir / "generations" / shot_id).mkdir(parents=True, exist_ok=True)

        # Extract episode number from shot_id for ProjectPaths; shot IDs that
        # do not start with "EP<digits>" fall back to episode 1.
        ep_match = re.match(r"EP(\d+)", shot_id)
        ep_num = int(ep_match.group(1)) if ep_match else 1

        paths = ProjectPaths.for_episode(project, ep_num)
        runner = StepRunner(store=store, paths=paths, episode=ep_num)

        ctx = DispatchContext(
            caller_id="seedance_zh_ab",
            step_runner=runner,
            project=project,
            episode=ep_num,
        )

        # Get duration from plan data if available. Null or non-numeric values
        # fall back to 5 so bad plan data cannot abort a paid run here, which
        # is outside the per-variant try/except below.
        routing = shot_lookup.get(shot_id, {}).get("routing_data") or {}
        raw_duration = routing.get("target_editorial_duration_s", 5)
        if isinstance(raw_duration, bool) or not isinstance(
            raw_duration, (int, float)
        ):
            raw_duration = 5
        duration = max(3, raw_duration)

        for lang, prompt in [("en", pair["en_prompt"]),
                             ("zh", pair["zh_prompt"])]:
            # Use a suffixed shot_id so both variants get distinct store entries
            variant_shot_id = f"{shot_id}_{lang}"
            print(f"  [{variant_shot_id}] Generating...", end=" ", flush=True)
            start = time.time()

            try:
                receipt = dispatch(
                    "video_i2v",
                    {
                        "shot_id": variant_shot_id,
                        "prompt": prompt,
                        # NOTE(review): spelled "seeddance" (double d) while the
                        # product is "Seedance" — confirm this matches the
                        # model key the dispatcher expects.
                        "model": "seeddance-2.0",
                        "duration": duration,
                        "aspect_ratio": "9:16",
                        "generate_audio": False,
                    },
                    context=ctx,
                )
                result = receipt.run_result
                latency = time.time() - start
                success = result.success
                output_path = result.output_path or ""
                cost_usd = read_cost_from_result(result)
                if success:
                    print(f"OK ({latency:.0f}s, ${cost_usd:.2f})")
                else:
                    print(f"FAIL: {getattr(result, 'error', 'unknown')}")
                results.append({
                    "shot_id": shot_id, "lang": lang,
                    "path": output_path, "success": success,
                    "latency": round(latency, 1),
                    "cost_usd": cost_usd,
                })
            except Exception as e:
                # Per-variant boundary: record the failure and keep going so
                # one bad generation does not discard the rest of the run.
                print(f"FAILED: {e}")
                results.append({
                    "shot_id": shot_id, "lang": lang,
                    "success": False, "error": str(e),
                })

    results_file = output_dir / "generation_results.json"
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    _write_comparison_html(pairs, results, output_dir)
    print(f"\nResults: {results_file}")


def _write_comparison_html(
    pairs: list[dict], results: list[dict], output_dir: Path
):
    """Generate blind comparison HTML for review.

    Writes ``zh_ab_comparison.html`` into ``output_dir``. A shot appears only
    when both its "en" and "zh" results succeeded. For each shot the two
    videos are assigned to "Variant A" / "Variant B" in random order; the
    language is stored in a ``data-lang`` attribute and shown when the
    reviewer clicks the "Click to reveal" line.

    Args:
        pairs: Translation pairs; ``shot_id``, ``shot_type`` and ``zh_score``
            are read from each.
        results: Generation records with ``shot_id``, ``lang``, ``success``
            and, for successes, ``path``.
        output_dir: Directory that receives the HTML file.
    """
    # Local import keeps this block self-contained.
    from html import escape

    html_path = output_dir / "zh_ab_comparison.html"
    result_lookup = {(r["shot_id"], r["lang"]): r for r in results}

    rows = []
    for pair in pairs:
        sid = pair["shot_id"]
        en_r = result_lookup.get((sid, "en"), {})
        zh_r = result_lookup.get((sid, "zh"), {})
        if not en_r.get("success") or not zh_r.get("success"):
            continue

        # Randomize A/B order for blind review
        if random.random() > 0.5:
            a_path, b_path = en_r["path"], zh_r["path"]
            a_lang, b_lang = "en", "zh"
        else:
            a_path, b_path = zh_r["path"], en_r["path"]
            a_lang, b_lang = "zh", "en"

        # Escape everything interpolated into markup so characters such as
        # & or " in a path or ID cannot break the attribute or the page.
        rows.append(
            f'<div class="comparison">'
            f'<h3>{escape(str(sid))} (type: {escape(str(pair["shot_type"]))}, '
            f'zh_score: {escape(str(pair["zh_score"]))})</h3>'
            f'<div class="videos">'
            f'<div class="variant"><h4>Variant A</h4>'
            f'<video controls width="360">'
            f'<source src="{escape(str(a_path))}" type="video/mp4"></video>'
            f'<p class="reveal" data-lang="{a_lang}">Click to reveal</p></div>'
            f'<div class="variant"><h4>Variant B</h4>'
            f'<video controls width="360">'
            f'<source src="{escape(str(b_path))}" type="video/mp4"></video>'
            f'<p class="reveal" data-lang="{b_lang}">Click to reveal</p></div>'
            f'</div></div>'
        )

    document = (
        '<!DOCTYPE html><html><head>'
        # Declare the encoding: the heading below contains a non-ASCII dash.
        '<meta charset="utf-8">'
        '<title>Seedance ZH A/B Test</title>'
        '<style>'
        'body{font-family:-apple-system,sans-serif;margin:2em;'
        'background:#1a1a2e;color:#e0e0e0}'
        '.comparison{margin:2em 0;padding:1em;border:1px solid #333;'
        'border-radius:8px}'
        '.videos{display:flex;gap:2em}'
        '.variant{flex:1}'
        '.reveal{cursor:pointer;color:#888;font-style:italic}'
        '.reveal.shown{color:#4fc3f7;font-weight:bold}'
        '</style>'
        '<script>'
        "document.addEventListener('click',e=>{"
        "if(e.target.classList.contains('reveal')){"
        "e.target.textContent=e.target.dataset.lang.toUpperCase();"
        "e.target.classList.add('shown')}});"
        '</script></head><body>'
        '<h1>Seedance 2.0 — Chinese Translation A/B Test</h1>'
        f'<p>Generated: {datetime.now(timezone.utc).isoformat()}</p>'
        f'<p>{len(rows)} shot comparisons. Click "Click to reveal" '
        'after rating to see which is EN/ZH.</p>'
        + '\n'.join(rows)
        + '</body></html>'
    )

    # Write as UTF-8 so the bytes match the charset declared in the page.
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(document)
    print(f"Comparison HTML: {html_path}")


def main():
    """Parse command-line arguments and run the requested phase.

    Shots come from ``--shots`` (explicit comma-separated IDs looked up in
    the project's plan files) or, when that flag is absent, from
    select_test_shots(). Output goes under
    ``<PROJECT_ROOT>/output/ab_tests/zh_translation/<project>``.
    """
    parser = argparse.ArgumentParser(
        description="Seedance Chinese translation A/B test",
    )
    parser.add_argument("--project", required=True, help="Project name")
    parser.add_argument(
        "--phase", required=True,
        choices=["translate", "compare", "generate"],
    )
    parser.add_argument(
        "--shots",
        help="Comma-separated shot IDs (overrides auto-selection)",
    )
    parser.add_argument(
        "--max-shots", type=int, default=20,
        help="Max shots for auto-selection",
    )
    args = parser.parse_args()

    plan_dir = CoreProjectPaths.for_project(args.project).plans_dir
    output_dir = (
        PROJECT_ROOT / "output" / "ab_tests" / "zh_translation" / args.project
    )

    if args.shots:
        # Drop empty entries left by stray or trailing commas.
        shot_ids = {s.strip() for s in args.shots.split(",") if s.strip()}
        shots = []
        for plan_file in sorted(plan_dir.glob("ep_*_plan.json")):
            with open(plan_file, encoding="utf-8") as f:
                plan = json.load(f)
            for shot in plan.get("shots") or []:
                if shot.get("shot_id") in shot_ids:
                    # Annotate like select_test_shots() does, so hand-picked
                    # shots carry a real score and source plan downstream.
                    shot["_source_plan"] = str(plan_file)
                    shot["_zh_score"] = _score_shot_for_zh_test(shot)
                    shots.append(shot)

        missing = shot_ids - {s.get("shot_id") for s in shots}
        if missing:
            print(
                "Warning: shot IDs not found in any plan: "
                + ", ".join(sorted(missing)),
                file=sys.stderr,
            )
    else:
        shots = select_test_shots(plan_dir, args.max_shots)

    if args.phase == "translate":
        phase_translate(shots, output_dir)
    elif args.phase == "compare":
        phase_compare(output_dir)
    elif args.phase == "generate":
        phase_generate(args.project, shots, output_dir)

if __name__ == "__main__":
    main()
