#!/usr/bin/env python3
"""seedance_vs_kling_v2v_ab.py — A/B test Kling O3 V2V vs Seedance V2V.

For each (source_video, prompt) pair, runs:
  1. Kling O3 V2V edit (model=kling-o3, --v2v-endpoint=o3_edit_standard)
  2. Seedance V2V (model=seeddance-2.0, --tier=standard)

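Writes one comparison_{prompt_slug}.md per (source × prompt) pair under
projects/{project}/output/ab_tests/{test_id}/. Each file records the output
path and cost that dispatch_cli.py reported for each leg (this script does not
itself move the generated videos into per-model subdirectories) and leaves
empty evaluation fields for JT to fill in (prompt adherence, background PSNR,
lighting integration, subjective notes).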

NOT a probe script — invokes dispatch_cli.py subprocess per leg, never bypasses
the unified pipeline.

Cost per pair (5s, standard tier): ~$0.56 Kling + ~$1.51 Seedance = ~$2.07.
"""

from __future__ import annotations

import argparse
import datetime as _dt
import re
import subprocess
import sys
from pathlib import Path
from typing import Optional

# Directory containing this script; dispatch_cli.py is resolved relative to it.
_HERE = Path(__file__).resolve().parent
# Three directory levels above this script. It is prepended to sys.path (when
# not already present) before the `recoil` import below.
# NOTE(review): presumably this is the directory that contains the `recoil`
# package — confirm against the repository layout.
_RECOIL_ROOT = _HERE.parent.parent.parent
if str(_RECOIL_ROOT) not in sys.path:
    sys.path.insert(0, str(_RECOIL_ROOT))

# Deliberately placed after the sys.path tweak above, hence the E402 suppression.
from recoil.core.paths import projects_root  # noqa: E402

# Sibling script that every A/B leg is dispatched through as a subprocess.
DISPATCH_CLI = _HERE / "dispatch_cli.py"

# Prompt texts for the built-in test cases. Each is kept as a named constant so
# the (slug, prompt) table below stays easy to scan.
_BACKGROUND_SWAP_PROMPT = (
    "[0s-5s] The subject in the reference video remains the same, "
    "but the background is a bustling Tokyo street at night, "
    "neon lights, raining."
)
_OBJECT_REPLACEMENT_PROMPT = (
    "[0s-5s] The subject in the reference video is holding a glowing "
    "futuristic datapad instead of a smartphone. "
    "The datapad casts blue light on their face."
)
_SEASONAL_CHANGE_PROMPT = (
    "[0s-5s] The forest in the reference video is covered in deep winter snow. "
    "Heavy snowflakes are falling. "
    "The subject is wearing a heavy winter coat."
)

# (slug, prompt) pairs run when the caller supplies no --prompt. The slug names
# the per-prompt comparison file.
DEFAULT_PROMPTS = [
    ("background_swap", _BACKGROUND_SWAP_PROMPT),
    ("object_replacement", _OBJECT_REPLACEMENT_PROMPT),
    ("seasonal_change", _SEASONAL_CHANGE_PROMPT),
]

# Per-second price estimates used for the up-front "Estimated total" preview in
# main(); each is multiplied by --duration and by the number of prompts.
# NOTE(review): presumably USD per generated second at standard tier (the
# values line up with the per-pair figures in the module docstring) — confirm
# against current provider pricing before trusting the estimate.
KLING_COST_PER_S = 0.112
SEEDANCE_COST_PER_S = 0.3034

# Hard cap per leg; if dispatch_cli hangs (network stall, fal.ai backpressure)
# we must surface failure rather than block the whole A/B run forever.
# Passed (in seconds) as the subprocess timeout in _run_leg.
LEG_TIMEOUT_S = 1800  # 30 minutes — generous for a 15s generation


def _slugify(text: str, maxlen: int = 40) -> str:
    s = re.sub(r"[^a-z0-9]+", "_", text.lower()).strip("_")
    return s[:maxlen] or "prompt"


def _build_kling_argv(
    *,
    project: str,
    source_video: Path,
    prompt: str,
    image_refs: Optional[str],
    duration: int,
    dry_run: bool,
) -> list[str]:
    """Assemble the dispatch_cli.py command line for the Kling O3 V2V leg.

    The source video is passed via ``--ref-video`` together with the fixed
    ``--v2v-endpoint o3_edit_standard``. ``--image-refs`` is appended only when
    ``image_refs`` is truthy, and ``--dry-run`` only when ``dry_run`` is set.

    Returns:
        The argv list, starting with the current Python interpreter.
    """
    flag_values = (
        ("--project", project),
        ("--model", "kling-o3"),
        ("--ref-video", str(source_video)),
        ("--v2v-endpoint", "o3_edit_standard"),
        ("--prompt", prompt),
        ("--duration", str(duration)),
    )
    cmd = [sys.executable, str(DISPATCH_CLI)]
    for flag, value in flag_values:
        cmd.extend((flag, value))
    if image_refs:
        cmd.extend(("--image-refs", image_refs))
    if dry_run:
        cmd.append("--dry-run")
    return cmd


def _build_seedance_argv(
    *,
    project: str,
    source_video: Path,
    prompt: str,
    image_refs: Optional[str],
    duration: int,
    dry_run: bool,
) -> list[str]:
    """Assemble the dispatch_cli.py command line for the Seedance V2V leg.

    The source video is passed via ``--source-video`` and the tier is fixed to
    ``standard``. ``--image-refs`` is appended only when ``image_refs`` is
    truthy, and ``--dry-run`` only when ``dry_run`` is set.

    Returns:
        The argv list, starting with the current Python interpreter.
    """
    flag_values = (
        ("--project", project),
        ("--model", "seeddance-2.0"),
        ("--source-video", str(source_video)),
        ("--prompt", prompt),
        ("--duration", str(duration)),
        ("--tier", "standard"),
    )
    cmd = [sys.executable, str(DISPATCH_CLI)]
    for flag, value in flag_values:
        cmd.extend((flag, value))
    if image_refs:
        cmd.extend(("--image-refs", image_refs))
    if dry_run:
        cmd.append("--dry-run")
    return cmd


def _run_leg(name: str, argv: list[str]) -> dict:
    """Run one A/B leg as a subprocess and return its captured result.

    The child's stdout/stderr are captured (not streamed) and echoed once the
    process finishes. The run is bounded by ``LEG_TIMEOUT_S``.

    Args:
        name: Human-readable leg label, echoed in the banner and the result.
        argv: Full command line (executable first) handed to ``subprocess.run``.

    Returns:
        A dict with keys ``"name"``, ``"returncode"``, ``"stdout"`` and
        ``"stderr"``. On timeout ``"returncode"`` is -1 and ``"stderr"`` starts
        with a ``TIMEOUT`` marker followed by any stderr captured so far.

    Raises:
        Anything ``subprocess.run`` raises other than ``TimeoutExpired`` (for
        example ``OSError`` when the executable cannot be started) propagates
        to the caller.
    """
    import shlex  # local import: only needed for the log line below

    def _as_text(captured) -> str:
        # TimeoutExpired may carry bytes even when text=True was requested,
        # and None when nothing was captured; normalise both to str.
        if not captured:
            return ""
        if isinstance(captured, bytes):
            return captured.decode("utf-8", errors="replace")
        return captured

    print(f"\n=== {name} ===")
    # Shell-quote each argument: prompts contain spaces and brackets, so a bare
    # " ".join() would log a command that is ambiguous and not re-runnable.
    print("Invoking:", " ".join(shlex.quote(arg) for arg in argv))
    try:
        result = subprocess.run(
            argv,
            capture_output=True,
            text=True,
            timeout=LEG_TIMEOUT_S,
        )
    except subprocess.TimeoutExpired as exc:
        print(f"TIMEOUT after {LEG_TIMEOUT_S}s")
        return {
            "name": name,
            "returncode": -1,
            "stdout": _as_text(exc.stdout),
            "stderr": f"TIMEOUT after {LEG_TIMEOUT_S}s\n{_as_text(exc.stderr)}",
        }
    print(result.stdout)
    if result.stderr:
        print("STDERR:", result.stderr)
    return {
        "name": name,
        "returncode": result.returncode,
        "stdout": result.stdout,
        "stderr": result.stderr,
    }


def _parse_output_path(stdout: str) -> Optional[str]:
    """Pull the output path printed by dispatch_cli's final '[OK] ... -> path ...' line."""
    m = re.search(r"\[OK\][^\n]*->\s*(\S+)", stdout)
    return m.group(1) if m else None


def _parse_cost(stdout: str) -> Optional[float]:
    """Pull the cost from dispatch_cli's '($X.XX)' suffix on the OK line."""
    m = re.search(r"\[OK\][^\n]*\(\$([0-9.]+)\)", stdout)
    return float(m.group(1)) if m else None


def _leg_section_lines(heading: str, result: Optional[dict]) -> list[str]:
    """Render one leg's markdown heading plus its result bullets (or SKIPPED)."""
    section = [heading, ""]
    if not result:
        section.append("- SKIPPED")
        return section
    captured = result["stdout"]
    out_path = _parse_output_path(captured)
    cost = _parse_cost(captured)
    # No "$" prefix on the placeholder, so a missing cost reads "N/A", not "$N/A".
    cost_text = "N/A" if cost is None else f"${cost:.2f}"
    section += [
        f"- Output: `{out_path or 'N/A'}`",
        f"- Cost:   {cost_text}",
        f"- Exit:   {result['returncode']}",
    ]
    return section


def _write_comparison(
    *,
    test_dir: Path,
    prompt_slug: str,
    source_video: Path,
    prompt: str,
    kling_result: Optional[dict],
    seedance_result: Optional[dict],
) -> Path:
    """Write the markdown comparison sheet for one (source, prompt) pair.

    The sheet lists the source video, the prompt, each leg's output path, cost
    and exit code as parsed from that leg's captured stdout (or SKIPPED when
    the leg result is None), followed by blank evaluation fields.

    Args:
        test_dir: Existing directory the file is written into.
        prompt_slug: Slug used in the title and the ``comparison_<slug>.md`` name.
        source_video: Source clip; it is stat()ed to report its size.
        prompt: Prompt text sent to both legs.
        kling_result: Result dict of the Kling leg, or None if skipped.
        seedance_result: Result dict of the Seedance leg, or None if skipped.

    Returns:
        Path of the written markdown file.

    Raises:
        OSError: If the source video cannot be stat()ed or the file cannot be
            written.
    """
    md = test_dir / f"comparison_{prompt_slug}.md"
    # Quote every line so a multi-line prompt stays inside the blockquote.
    quoted_prompt = [f"> {line}" for line in prompt.splitlines()] or ["> "]
    lines = [
        f"# A/B Comparison — {prompt_slug}",
        "",
        f"**Date:** {_dt.datetime.now().isoformat(timespec='seconds')}",
        f"**Source video:** `{source_video}`",
        f"**Source size:** {source_video.stat().st_size / 1_048_576:.1f} MB",
        "",
        "## Prompt",
        "",
        *quoted_prompt,
        "",
        *_leg_section_lines("## Kling O3 V2V edit", kling_result),
        "",
        *_leg_section_lines("## Seedance V2V edit", seedance_result),
        "",
        "## Evaluation (JT fills in)",
        "",
        "- Prompt adherence (kling):    [ pass / fail ]",
        "- Prompt adherence (seedance): [ pass / fail ]",
        "- Background PSNR (kling):     [ dB ]",
        "- Background PSNR (seedance):  [ dB ]",
        "- Lighting integration (kling):    [ 1-10 ]",
        "- Lighting integration (seedance): [ 1-10 ]",
        "- Subjective notes:",
        "  - ",
        "",
    ]
    md.write_text("\n".join(lines), encoding="utf-8")
    print(f"Wrote {md}")
    return md


def _dedupe_prompt_slugs(prompt_texts: list[str]) -> list[tuple[str, str]]:
    """Pair each prompt with a slug, suffixing repeats so files never collide."""
    used: set[str] = set()
    pairs = []
    for text in prompt_texts:
        base = _slugify(text)
        slug, n = base, 1
        while slug in used:
            n += 1
            slug = f"{base}_{n}"
        used.add(slug)
        pairs.append((slug, text))
    return pairs


def _run_leg_guarded(name: str, argv: list[str]) -> dict:
    """Run one leg, turning any exception into a failed-result dict."""
    try:
        return _run_leg(name, argv)
    except Exception as exc:  # noqa: BLE001
        # Report it here: the comparison sheet only records the exit code, so
        # the reason would otherwise never be shown anywhere.
        print(f"ERROR: {name} raised {exc!r}")
        return {
            "name": name,
            "returncode": -1,
            "stdout": "",
            "stderr": f"EXCEPTION: {exc!r}",
        }


def main() -> int:
    """Parse CLI arguments, run the requested A/B legs and write comparisons.

    For every prompt the enabled legs are run one after the other through
    dispatch_cli.py, then a comparison sheet is written. A failing leg or a
    failed sheet write is reported and the run continues with the next item.

    Returns:
        0 once all prompts have been processed (even if individual legs
        failed), 1 when --source-video is missing or not a regular file, 2 when
        both legs are skipped or --duration is not positive.
    """
    p = argparse.ArgumentParser(
        description=(
            "A/B test Kling O3 V2V vs Seedance V2V edit. Invokes "
            "dispatch_cli.py for each leg. Records outputs + costs into "
            "projects/{project}/output/ab_tests/{test_id}/comparison_*.md."
        )
    )
    p.add_argument("--project", required=True)
    p.add_argument("--source-video", required=True)
    p.add_argument(
        "--prompt",
        action="append",
        help=(
            "Prompt for the test. Repeat to run multiple prompts on the same "
            "source. If omitted, runs the 3 default prompts from the consult."
        ),
    )
    p.add_argument(
        "--image-refs",
        default=None,
        help="Optional --image-refs string passed to BOTH legs.",
    )
    p.add_argument("--duration", type=int, default=5)
    p.add_argument(
        "--test-id",
        default=None,
        help="Test directory name. Default = source video basename + timestamp.",
    )
    p.add_argument("--skip-kling", action="store_true")
    p.add_argument("--skip-seedance", action="store_true")
    p.add_argument(
        "--dry-run",
        action="store_true",
        help="Print intended invocations + per-leg cost. Does not submit.",
    )
    args = p.parse_args()

    source_video = Path(args.source_video).expanduser().resolve()
    if not source_video.exists():
        print(f"ERROR: --source-video not found: {source_video}")
        return 1
    if not source_video.is_file():
        # A directory passes exists() but cannot be used as a source clip.
        print(f"ERROR: --source-video is not a file: {source_video}")
        return 1

    if args.skip_kling and args.skip_seedance:
        print("ERROR: cannot skip both legs.")
        return 2

    if args.duration <= 0:
        # Reject before creating directories or printing a meaningless estimate.
        print(f"ERROR: --duration must be positive, got {args.duration}.")
        return 2

    # Custom prompts get de-duplicated slugs so two prompts that slugify to
    # the same text do not overwrite each other's comparison file.
    prompts = _dedupe_prompt_slugs(args.prompt) if args.prompt else DEFAULT_PROMPTS

    test_id = (
        args.test_id
        or f"{source_video.stem}_{_dt.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )
    test_dir = projects_root() / args.project / "output" / "ab_tests" / test_id
    test_dir.mkdir(parents=True, exist_ok=True)
    print(f"Test dir: {test_dir}")

    # Cost preview (covers --dry-run AND real runs — show total before firing)
    legs_per_prompt = 2 - int(args.skip_kling) - int(args.skip_seedance)
    n_prompts = len(prompts)
    est_total = 0.0
    if not args.skip_kling:
        est_total += KLING_COST_PER_S * args.duration * n_prompts
    if not args.skip_seedance:
        est_total += SEEDANCE_COST_PER_S * args.duration * n_prompts
    print(
        f"Estimated total: ${est_total:.2f} "
        f"({n_prompts} prompts × {legs_per_prompt} legs × {args.duration}s)"
    )

    # Each leg is guarded so a Kling failure does not skip the Seedance
    # comparison (or vice versa). The whole A/B loop also continues across
    # prompts — one bad prompt does not abort the run.
    for slug, prompt in prompts:
        print(f"\n--- prompt: {slug} ---")
        leg_kwargs = {
            "project": args.project,
            "source_video": source_video,
            "prompt": prompt,
            "image_refs": args.image_refs,
            "duration": args.duration,
            "dry_run": args.dry_run,
        }
        kling_result = None
        seedance_result = None
        if not args.skip_kling:
            kling_result = _run_leg_guarded(
                "Kling O3 V2V", _build_kling_argv(**leg_kwargs)
            )
        if not args.skip_seedance:
            seedance_result = _run_leg_guarded(
                "Seedance V2V", _build_seedance_argv(**leg_kwargs)
            )
        try:
            _write_comparison(
                test_dir=test_dir,
                prompt_slug=slug,
                source_video=source_video,
                prompt=prompt,
                kling_result=kling_result,
                seedance_result=seedance_result,
            )
        except Exception as exc:  # noqa: BLE001
            print(f"WARN: comparison.md write failed for {slug}: {exc!r}")

    print(f"\nDone. Test dir: {test_dir}")
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
