#!/usr/bin/env python3
"""Enrichment Parity Test Suite — validates builder output against PROMPT_BIBLE.

For each builder (seeddance_t2v, veo, kling_t2v, seedream), calls it with
mock data that includes characters, wardrobe, film stock, etc. and checks
that the output prompt contains evidence of each enrichment declared in the
model's enrichment_profile in PROMPT_BIBLE.yaml.

Uses fuzzy string detection — checks if key words from the enrichment data
appear in the prompt. Skips enrichments that can't be detected by string
matching (verb_calibration, format_mode, labeled_roles_syntax, etc.).

Usage:
    python3 tools/test_enrichment_parity.py
    python3 tools/test_enrichment_parity.py -v   # verbose output

Exit codes:
    0 — all detectable enrichments pass
    1 — one or more enrichment checks fail
"""

import argparse
import sys
from collections.abc import Callable
from pathlib import Path

# Import-path bootstrap. Every entry is inserted at index 0, so the LAST
# insert below (this file's own directory) is searched first, then the
# "pipeline" directory, then the project root.

# Ensure project root is importable (two directory levels above this file)
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
sys.path.insert(0, _PROJECT_ROOT)
# Ensure pipeline is importable for prompt_engine
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "pipeline"))
# Ensure tools directory (the directory containing this file) is importable
# for test_seeddance_builders
sys.path.insert(0, str(Path(__file__).resolve().parent))

# Import mock data from test_seeddance_builders
from test_seeddance_builders import (
    MOCK_SHOT,
    MOCK_BIBLE,
    MOCK_PROJECT_CONFIG,
    INTEGRATION_BIBLE,
    INTEGRATION_SHOT_ACTION,
    TestResult,
)

from recoil.pipeline._lib.prompt_engine import (
    extract_core_semantics,
    build_seeddance_t2v_prompt,
    build_veo_prompt,
    build_kling_t2v_prompt,
    build_seedream_prompt,
)
from recoil.pipeline._lib.bible_loader import get_model_rules


# ──────────────────────────────────────────────────────────────────────
# Enrichments to SKIP (not detectable via string matching)
# ──────────────────────────────────────────────────────────────────────

# Profile entries that are counted as "skipped" and never handed to a
# detector. Kept in alphabetical order for easy scanning.
_SKIP_ENRICHMENTS = {
    "audio_cues_after_delimiter",
    "audio_direction",  # Veo-specific audio cue format — not string-detectable
    "character_descs_reinforcement",
    "environment_line_motion_only",
    "format_mode",
    "identity_lock",
    "labeled_roles_syntax",
    "lighting_prose",
    "lighting_prose_comma_separated",
    "lighting_prose_with_anchors",
    "ref_declarations",
    "single_verb_enforce",
    "spatial_continuity",
    "verb_calibration",
}


# ──────────────────────────────────────────────────────────────────────
# Detection functions — fuzzy string presence checks
# ──────────────────────────────────────────────────────────────────────

def _detect_subject_line(prompt: str, cs: dict) -> bool:
    """Check whether the ``subject_line`` enrichment is evident in the prompt.

    Two case-insensitive probes are tried in order:

    1. The first 20 characters of ``subject_line`` (whitespace-trimmed) as a
       contiguous fragment.
    2. Fallback: any one of the first 3 whitespace-separated words, with
       trailing ``.``, ``,`` or ``;`` removed (covers builders like Veo that
       decompose subject_line into character_descs + action).

    Args:
        prompt: Builder output to search.
        cs: Mapping read via ``cs.get("subject_line", "")``.

    Returns:
        True when the value is missing/empty (nothing to check) or when
        either probe matches; False otherwise.
    """
    val = cs.get("subject_line", "")
    if not val:
        return True  # nothing to check
    prompt_lower = prompt.lower()
    # Primary: first 20 chars as fragment
    fragment = val[:20].lower().strip()
    if fragment in prompt_lower:
        return True
    # Fallback: first 3 words (character name usually).
    tokens = (w.lower().rstrip(".,;") for w in val.split()[:3])
    # A punctuation-only word strips down to "", and "" is a substring of
    # every string — such tokens must be ignored or the check always passes.
    return any(tok in prompt_lower for tok in tokens if tok)


def _detect_action_line(prompt: str, cs: dict) -> bool:
    """Check if any of the first 3 words of ``action_line`` appear in prompt.

    Matching is a case-insensitive substring test; trailing ``.``, ``,`` and
    ``;`` are removed from each word first.

    Args:
        prompt: Builder output to search.
        cs: Mapping read via ``cs.get("action_line", "")``.

    Returns:
        True when there is nothing to check (value missing/empty, or no
        usable word remains after stripping punctuation) or when at least
        one word is found; False otherwise.
    """
    val = cs.get("action_line", "")
    if not val:
        return True
    prompt_lower = prompt.lower()
    # Drop words that strip down to "": "" is a substring of every string
    # and would make the check pass regardless of the prompt.
    tokens = [t for t in (w.lower().rstrip(".,;") for w in val.split()[:3]) if t]
    if not tokens:
        return True  # only punctuation — nothing meaningful to look for
    return any(t in prompt_lower for t in tokens)


def _detect_film_stock(prompt: str, cs: dict) -> bool:
    """Report whether the full ``film_stock`` value occurs in the prompt.

    The comparison is a case-insensitive substring test. A missing or empty
    ``film_stock`` counts as a pass because there is nothing to look for.
    """
    stock = cs.get("film_stock", "")
    if stock:
        return stock.lower() in prompt.lower()
    return True


def _detect_quality_suffix(prompt: str, cs: dict) -> bool:
    """Report whether the prompt mentions "4K" or "Ultra HD" (any casing).

    ``cs`` is accepted only so the function fits the common detector
    signature; it is not consulted.
    """
    haystack = prompt.lower()
    return any(marker in haystack for marker in ("4k", "ultra hd"))


def _detect_audio_directive(prompt: str, cs: dict) -> bool:
    """Report whether the prompt contains "no music" or "no score" (any casing).

    ``cs`` is accepted only so the function fits the common detector
    signature; it is not consulted.
    """
    haystack = prompt.lower()
    return any(phrase in haystack for phrase in ("no music", "no score"))


def _detect_arc_preamble(prompt: str, cs: dict) -> bool:
    """Report whether the start of ``arc_preamble`` occurs in the prompt.

    Only the first 15 characters are compared (lower-cased, then
    whitespace-trimmed). An absent or empty preamble passes automatically.
    """
    preamble = cs.get("arc_preamble", "")
    if preamble:
        needle = preamble[:15].lower().strip()
        return needle in prompt.lower()
    return True  # not present — pass


def _detect_character_descs(prompt: str, cs: dict) -> bool:
    """Check if any of the first 3 words of ``character_descs`` appear in prompt.

    Matching is a case-insensitive substring test; trailing ``.``, ``,`` and
    ``;`` are removed from each word first.

    Args:
        prompt: Builder output to search.
        cs: Mapping read via ``cs.get("character_descs", "")``.

    Returns:
        True when there is nothing to check (value missing/empty, or no
        usable word remains after stripping punctuation) or when at least
        one word is found; False otherwise.
    """
    val = cs.get("character_descs", "")
    if not val:
        return True
    prompt_lower = prompt.lower()
    # Drop words that strip down to "": "" is a substring of every string
    # and would make the check pass regardless of the prompt.
    tokens = [t for t in (w.lower().rstrip(".,;") for w in val.split()[:3]) if t]
    if not tokens:
        return True  # only punctuation — nothing meaningful to look for
    return any(t in prompt_lower for t in tokens)


def _detect_wardrobe(prompt: str, cs: dict) -> bool:
    """Check if any of the first 3 words of ``wardrobe`` appear in prompt.

    Matching is a case-insensitive substring test; trailing ``.``, ``,`` and
    ``;`` are removed from each word first.

    Args:
        prompt: Builder output to search.
        cs: Mapping read via ``cs.get("wardrobe", "")``.

    Returns:
        True when there is nothing to check (value missing/empty, or no
        usable word remains after stripping punctuation) or when at least
        one word is found; False otherwise.
    """
    val = cs.get("wardrobe", "")
    if not val:
        return True
    prompt_lower = prompt.lower()
    # Drop words that strip down to "": "" is a substring of every string
    # and would make the check pass regardless of the prompt.
    tokens = [t for t in (w.lower().rstrip(".,;") for w in val.split()[:3]) if t]
    if not tokens:
        return True  # only punctuation — nothing meaningful to look for
    return any(t in prompt_lower for t in tokens)


def _detect_scene_visual_locks(prompt: str, cs: dict) -> bool:
    """Check if any of the first 2 words of ``scene_visual_locks`` appear in prompt.

    Matching is a case-insensitive substring test; trailing ``.``, ``,`` and
    ``;`` are removed from each word first.

    Args:
        prompt: Builder output to search.
        cs: Mapping read via ``cs.get("scene_visual_locks", "")``.

    Returns:
        True when there is nothing to check (value missing/empty, or no
        usable word remains after stripping punctuation) or when at least
        one word is found; False otherwise.
    """
    val = cs.get("scene_visual_locks", "")
    if not val:
        return True
    prompt_lower = prompt.lower()
    # Drop words that strip down to "": "" is a substring of every string
    # and would make the check pass regardless of the prompt.
    tokens = [t for t in (w.lower().rstrip(".,;") for w in val.split()[:2]) if t]
    if not tokens:
        return True  # only punctuation — nothing meaningful to look for
    return any(t in prompt_lower for t in tokens)


def _detect_veo_quality_footer(prompt: str, cs: dict) -> bool:
    """Report whether the word "photorealistic" occurs in the prompt (any casing).

    ``cs`` is accepted only so the function fits the common detector
    signature; it is not consulted.
    """
    lowered = prompt.lower()
    return lowered.find("photorealistic") != -1


def _detect_director_notes(prompt: str, cs: dict) -> bool:
    """Report whether the start of ``director_notes`` occurs in the prompt.

    Only the first 10 characters are compared (lower-cased, then
    whitespace-trimmed). Missing or empty notes pass automatically.
    """
    notes = cs.get("director_notes", "")
    if notes:
        needle = notes[:10].lower().strip()
        return needle in prompt.lower()
    return True


def _detect_environment_line(prompt: str, cs: dict) -> bool:
    """Report whether the start of ``environment_line`` occurs in the prompt.

    Only the first 20 characters are compared (lower-cased, then
    whitespace-trimmed). A missing or empty value passes automatically.
    """
    environment = cs.get("environment_line", "")
    if environment:
        needle = environment[:20].lower().strip()
        return needle in prompt.lower()
    return True


def _detect_emotion_line(prompt: str, cs: dict) -> bool:
    """Check if any of the first 3 words of ``emotion_line`` appear in prompt.

    Matching is a case-insensitive substring test; trailing ``.``, ``,`` and
    ``;`` are removed from each word first.

    Args:
        prompt: Builder output to search.
        cs: Mapping read via ``cs.get("emotion_line", "")``.

    Returns:
        True when there is nothing to check (value missing/empty, or no
        usable word remains after stripping punctuation) or when at least
        one word is found; False otherwise.
    """
    val = cs.get("emotion_line", "")
    if not val:
        return True
    prompt_lower = prompt.lower()
    # Drop words that strip down to "": "" is a substring of every string
    # and would make the check pass regardless of the prompt.
    tokens = [t for t in (w.lower().rstrip(".,;") for w in val.split()[:3]) if t]
    if not tokens:
        return True  # only punctuation — nothing meaningful to look for
    return any(t in prompt_lower for t in tokens)


def _detect_kinetic_action(prompt: str, cs: dict) -> bool:
    """Check if any of the first 3 words of ``kinetic_action`` appear in prompt.

    Matching is a case-insensitive substring test; trailing ``.``, ``,`` and
    ``;`` are removed from each word first.

    Args:
        prompt: Builder output to search.
        cs: Mapping read via ``cs.get("kinetic_action", "")``.

    Returns:
        True when there is nothing to check (value missing/empty, or no
        usable word remains after stripping punctuation) or when at least
        one word is found; False otherwise.
    """
    val = cs.get("kinetic_action", "")
    if not val:
        return True
    prompt_lower = prompt.lower()
    # Drop words that strip down to "": "" is a substring of every string
    # and would make the check pass regardless of the prompt.
    tokens = [t for t in (w.lower().rstrip(".,;") for w in val.split()[:3]) if t]
    if not tokens:
        return True  # only punctuation — nothing meaningful to look for
    return any(t in prompt_lower for t in tokens)


def _detect_character_anchor(prompt: str, cs: dict) -> bool:
    """Check if any of the first 3 words of ``character_anchor`` appear in prompt.

    Matching is a case-insensitive substring test; trailing ``.``, ``,`` and
    ``;`` are removed from each word first.

    Args:
        prompt: Builder output to search.
        cs: Mapping read via ``cs.get("character_anchor", "")``.

    Returns:
        True when there is nothing to check (value missing/empty, or no
        usable word remains after stripping punctuation) or when at least
        one word is found; False otherwise.
    """
    val = cs.get("character_anchor", "")
    if not val:
        return True
    prompt_lower = prompt.lower()
    # Drop words that strip down to "": "" is a substring of every string
    # and would make the check pass regardless of the prompt.
    tokens = [t for t in (w.lower().rstrip(".,;") for w in val.split()[:3]) if t]
    if not tokens:
        return True  # only punctuation — nothing meaningful to look for
    return any(t in prompt_lower for t in tokens)


def _detect_camera_movement(prompt: str, cs: dict) -> bool:
    """Advisory detector for ``camera_movement`` — always passes.

    Builders often express camera movement with synonyms rather than the
    literal value, so a plain substring test is not a reliable signal. The
    earlier implementation ran such tests but OR-ed each result with
    ``True``, which made them dead code; this version states the contract
    directly instead of implying a check that never influenced the outcome.

    Args:
        prompt: Builder output (unused; kept for the common detector signature).
        cs: Enrichment data (unused; kept for the common detector signature).

    Returns:
        Always True.
    """
    # NOTE(review): if real validation is wanted, add a synonym table here;
    # until then this enrichment is effectively unvalidated.
    return True


def _detect_audio_direction(prompt: str, cs: dict) -> bool:
    """Validate audio direction against the ``allow_music`` setting.

    When ``allow_music`` is truthy (or absent, defaulting to True) there is
    nothing to enforce and the check passes. Otherwise the prompt must
    contain "no music" or "no score" (any casing).
    """
    if cs.get("allow_music", True):
        return True
    lowered = prompt.lower()
    return any(phrase in lowered for phrase in ("no music", "no score"))


# ──────────────────────────────────────────────────────────────────────
# Enrichment → detection function mapping
# ──────────────────────────────────────────────────────────────────────

# Maps an enrichment name from a PROMPT_BIBLE enrichment_profile to the
# function that looks for evidence of it. Every detector takes
# (prompt, cs) and returns True on pass. Several profile names share one
# detector (e.g. the "_conditional" / "_compressed" variants).
#
# The value type is spelled with ``Callable``: the builtin ``callable`` is a
# predicate function, not a type, and is rejected by type checkers.
#
# NOTE: "audio_direction" is also listed in _SKIP_ENRICHMENTS, and the runner
# consults the skip set before this table, so that entry is currently never
# invoked.
_DETECTORS: dict[str, Callable[[str, dict], bool]] = {
    "subject_line": _detect_subject_line,
    "action_line": _detect_action_line,
    "film_stock": _detect_film_stock,
    "quality_suffix": _detect_quality_suffix,
    "audio_directive": _detect_audio_directive,
    "arc_preamble": _detect_arc_preamble,
    "arc_preamble_conditional": _detect_arc_preamble,
    "character_descs": _detect_character_descs,
    "wardrobe": _detect_wardrobe,
    "scene_visual_locks": _detect_scene_visual_locks,
    "scene_visual_locks_compressed": _detect_scene_visual_locks,
    "veo_quality_footer": _detect_veo_quality_footer,
    "director_notes": _detect_director_notes,
    "environment_line": _detect_environment_line,
    "emotion_line": _detect_emotion_line,
    "kinetic_action": _detect_kinetic_action,
    "character_anchor": _detect_character_anchor,
    "camera_movement": _detect_camera_movement,
    "audio_direction": _detect_audio_direction,
}


# ──────────────────────────────────────────────────────────────────────
# Builder configurations: model name, builder function, PROMPT_BIBLE key
# ──────────────────────────────────────────────────────────────────────

# One spec per builder under test:
#   name        — label used in result titles and verbose output
#   model       — key passed to get_model_rules()
#   profile_key — which list inside the model's enrichment_profile to check
#   builder     — the builder function itself
#   call        — adapter invoked as call(shot, bible, config); it forwards
#                 to the builder by keyword with episode fixed at 1
_BUILDERS = [
    dict(
        name="SeedDance T2V",
        model="seeddance-2.0",
        profile_key="t2v",
        builder=build_seeddance_t2v_prompt,
        call=lambda shot, bible, config: build_seeddance_t2v_prompt(
            shot=shot, bible=bible, project_config=config, episode=1,
        ),
    ),
    dict(
        name="Veo 3.1 T2V",
        model="veo-3.1",
        profile_key="t2v",
        builder=build_veo_prompt,
        call=lambda shot, bible, config: build_veo_prompt(
            shot=shot, bible=bible, project_config=config, episode=1,
        ),
    ),
    dict(
        name="Kling V3 T2V",
        model="kling-v3",
        profile_key="t2v",
        builder=build_kling_t2v_prompt,
        call=lambda shot, bible, config: build_kling_t2v_prompt(
            shot=shot, bible=bible, project_config=config, episode=1,
        ),
    ),
    dict(
        name="Seedream v4.5",
        model="seedream-v4.5",
        profile_key="default",
        builder=build_seedream_prompt,
        call=lambda shot, bible, config: build_seedream_prompt(
            shot=shot, bible=bible, project_config=config, episode=1,
        ),
    ),
]


# ──────────────────────────────────────────────────────────────────────
# Test runner
# ──────────────────────────────────────────────────────────────────────

def test_enrichment_parity(
    shot: dict,
    bible: dict,
    config: dict,
    verbose: bool = False,
) -> list[TestResult]:
    """Test all builders against their PROMPT_BIBLE enrichment profiles.

    For each builder:
      1. Extract CoreSemantics from the shot (done once, shared by all)
      2. Call the builder to get the prompt
      3. Look up the model's enrichment_profile from PROMPT_BIBLE
      4. For each enrichment in the profile, run the detector
      5. Skip enrichments in _SKIP_ENRICHMENTS or without a detector

    Failures are attributed to the stage that produced them: an exception
    from the builder is reported as a builder failure, while an exception
    during profile lookup, verbose printing, or detection is reported as an
    enrichment-check failure (previously all were labelled as builder
    exceptions).

    Args:
        shot: Shot data passed to the builders and to extract_core_semantics.
        bible: Bible data passed alongside the shot.
        config: Project config passed alongside the shot.
        verbose: When True, print a prompt preview and per-builder counts.

    Returns:
        One TestResult per entry in _BUILDERS, in the same order.

    NOTE(review): the ``test_`` prefix combined with non-fixture parameters
    means pytest would try to collect this function if the file were ever
    picked up by it — confirm the file is only run as a script.
    """
    results = []

    # Extract CoreSemantics once (same data for all builders)
    cs = extract_core_semantics(shot, bible, config, episode=1)

    for builder_spec in _BUILDERS:
        name = builder_spec["name"]
        model = builder_spec["model"]
        profile_key = builder_spec["profile_key"]

        result = TestResult(f"Enrichment parity: {name}")
        # Register up front so every exit path below reports exactly once.
        results.append(result)

        # Stage 1: build the prompt. Broad catch is deliberate — any builder
        # error must become a failed result, not abort the whole suite.
        try:
            prompt = builder_spec["call"](shot, bible, config)
        except Exception as e:
            result.fail(f"Builder raised exception: {e}")
            continue

        # Stage 2: profile lookup and detection.
        try:
            if verbose:
                word_count = len(prompt.split())
                print(f"\n  --- {name} Prompt ({word_count} words) ---")
                print(f"  {prompt[:300]}...")
                print("  --- end ---\n")

            # Get enrichment profile from PROMPT_BIBLE
            model_rules = get_model_rules(model)
            if model_rules is None:
                result.fail(f"Model '{model}' not found in PROMPT_BIBLE")
                continue

            # `or {}` tolerates keys that are present but null.
            prompt_rules = model_rules.get("prompt") or {}
            enrichment_profile = prompt_rules.get("enrichment_profile") or {}

            # A dict profile is keyed by profile_key; anything else is taken
            # to be the enrichment list itself.
            if isinstance(enrichment_profile, dict):
                enrichments = enrichment_profile.get(profile_key, [])
            else:
                enrichments = enrichment_profile

            if not enrichments:
                result.fail(
                    f"No enrichment_profile.{profile_key} found for {model}"
                )
                continue

            # Check each enrichment
            checked = 0
            skipped = 0
            for enrichment in enrichments:
                if enrichment in _SKIP_ENRICHMENTS:
                    skipped += 1
                    continue

                detector = _DETECTORS.get(enrichment)
                if detector is None:
                    skipped += 1
                    if verbose:
                        print(f"    [SKIP] {enrichment} — no detector")
                    continue

                checked += 1
                if not detector(prompt, cs):
                    result.fail(
                        f"Enrichment '{enrichment}' not detected in {name} output"
                    )

            if verbose:
                print(
                    f"  {name}: {checked} checked, {skipped} skipped, "
                    f"{len(enrichments)} total enrichments"
                )

        except Exception as e:
            # Not a builder problem: the prompt was produced successfully.
            result.fail(
                f"Enrichment check raised exception: {type(e).__name__}: {e}"
            )

    return results


# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────

def main():
    """Command-line entry point.

    Parses the ``-v/--verbose`` flag, runs the parity checks on the
    integration mock data, prints a per-builder report plus a summary, and
    exits with status 0 when every builder passed or 1 otherwise.
    """
    cli = argparse.ArgumentParser(
        description="Enrichment Parity Test Suite — validates builder output "
                    "against PROMPT_BIBLE enrichment profiles.",
    )
    cli.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Print prompts and detection details",
    )
    args = cli.parse_args()

    banner = "=" * 60
    print(banner)
    print("Enrichment Parity Test Suite")
    print(banner)

    # Use INTEGRATION_SHOT_ACTION + INTEGRATION_BIBLE for richer data
    # (has character with visual_description, wardrobe, etc.)
    shot, bible, config = (
        INTEGRATION_SHOT_ACTION,
        INTEGRATION_BIBLE,
        MOCK_PROJECT_CONFIG,
    )

    print("\nUsing INTEGRATION_SHOT_ACTION + INTEGRATION_BIBLE mock data")
    print(f"Film stock: {config.get('film_stock', 'none')}")
    print(f"Allow music: {config.get('allow_music', True)}")

    # Run parity tests
    print("\n--- Enrichment Parity Tests ---")
    outcomes = test_enrichment_parity(shot, bible, config, verbose=args.verbose)

    # Per-builder report
    print("\nResults:")
    for outcome in outcomes:
        print(outcome)

    # Summary: the run succeeds only when every builder passed.
    ok_count = sum(1 for outcome in outcomes if outcome.passed)
    builder_count = len(outcomes)
    print(f"\n{ok_count}/{builder_count} builders passed enrichment parity")

    if ok_count != builder_count:
        print("\nSome enrichment checks FAILED.")
        sys.exit(1)

    print("\nAll detectable enrichments validated.")
    sys.exit(0)


# Run the suite only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()
