"""Tests for Phase 3 model profile additions: max_cost_per_shot_usd + forbidden strategies.

Verifies:
1. All video models have max_cost_per_shot_usd > 0
2. All image models have max_cost_per_shot_usd > 0
3. NBP (gemini-3-pro-image-preview) forbidden_reroll_strategies expanded to 4 entries
4. Specific cost values match the spec
"""

import json
from pathlib import Path

from recoil.core.model_profiles import iter_model_ids


def _load_profiles() -> dict:
    """Load model_profiles.json from the config directory."""
    # Navigate from test file: tests/ -> pipeline/ -> recoil/ -> config/
    config_path = (
        Path(__file__).resolve().parent.parent.parent / "config" / "model_profiles.json"
    )
    if config_path.exists():
        return json.loads(config_path.read_text())
    raise FileNotFoundError(f"model_profiles.json not found at {config_path}")


def test_all_video_models_have_max_cost_per_shot():
    """Every video model must declare a numeric max_cost_per_shot_usd > 0.

    Missing, non-numeric (strings, booleans, null) and NaN values are all
    reported as missing. They do not crash the comparison with a TypeError,
    so a single failure lists every offending model.
    """
    profiles = _load_profiles()
    missing = []
    for model_id in iter_model_ids(profiles):
        profile = profiles[model_id]
        if profile.get("modality") != "video":
            continue
        cost = profile.get("max_cost_per_shot_usd")
        # bool subclasses int, so a JSON `true` must be rejected explicitly.
        is_number = isinstance(cost, (int, float)) and not isinstance(cost, bool)
        # `not (cost > 0)` also rejects NaN, which `cost <= 0` would let through.
        if not (is_number and cost > 0):
            missing.append(model_id)
    assert not missing, f"Video models missing max_cost_per_shot_usd > 0: {missing}"


def test_all_image_models_have_max_cost_per_shot():
    """Every image model must declare a numeric max_cost_per_shot_usd > 0.

    Model ids come from ``iter_model_ids``, the same helper the video-model
    test uses, so metadata entries are skipped consistently. Missing,
    non-numeric (strings, booleans, null) and NaN values are all reported as
    missing instead of raising a TypeError.
    """
    profiles = _load_profiles()
    missing = []
    for model_id in iter_model_ids(profiles):
        profile = profiles[model_id]
        if profile.get("modality") != "image":
            continue
        cost = profile.get("max_cost_per_shot_usd")
        # bool subclasses int, so a JSON `true` must be rejected explicitly.
        is_number = isinstance(cost, (int, float)) and not isinstance(cost, bool)
        # `not (cost > 0)` also rejects NaN, which `cost <= 0` would let through.
        if not (is_number and cost > 0):
            missing.append(model_id)
    assert not missing, f"Image models missing max_cost_per_shot_usd > 0: {missing}"


def test_nbp_forbidden_reroll_strategies_expanded():
    """NBP (gemini-3-pro-image-preview) must forbid exactly 4 strategies per Phase 3 spec.

    The strategy names are compared as a set, and the raw list is also checked
    for duplicates. A set comparison alone would accept, for example, a
    5-entry list that repeats one strategy.
    """
    profiles = _load_profiles()
    nbp = profiles["gemini-3-pro-image-preview"]
    strategies = nbp["forbidden_reroll_strategies"]
    forbidden = set(strategies)
    expected = {
        "style_transfer",
        "prompt_append",
        "prompt_rewrite",
        "negative_prompt_injection",
    }
    assert forbidden == expected, (
        f"NBP forbidden_reroll_strategies: got {forbidden}, expected {expected}"
    )
    assert len(strategies) == len(forbidden), (
        f"NBP forbidden_reroll_strategies contains duplicates: {strategies}"
    )


def test_specific_cost_values():
    """Verify specific max_cost_per_shot_usd values from the spec.

    Every model in the expectation table is checked, and all problems are
    reported together. This includes a model whose profile is absent
    entirely, which would otherwise abort the loop with a KeyError.
    """
    profiles = _load_profiles()
    expected = {
        "gemini-3-pro-image-preview": 1.50,
        "seeddance-2.0": 4.55,
        "kling-v3": 2.00,
        "kling-v3-direct": 2.00,
        "kling-2.5": 2.00,
        "kling-o3": 2.00,
        "kling-o3-direct": 2.00,
        "veo-3.1": 6.00,
        "veo-3.1-generate-preview": 6.00,
    }
    mismatches = []
    for model_id, expected_cost in expected.items():
        profile = profiles.get(model_id)
        if profile is None:
            mismatches.append(f"{model_id}: profile missing from model_profiles.json")
            continue
        actual = profile.get("max_cost_per_shot_usd")
        # Exact equality is intentional: a JSON decimal literal and the same
        # Python literal round to the same double.
        if actual != expected_cost:
            mismatches.append(f"{model_id}: expected {expected_cost}, got {actual}")
    assert not mismatches, "Cost mismatches:\n" + "\n".join(mismatches)
