# ==============================================================================
# PORTED FROM STARSEND: lib/model_profiles.py
# DATE: 2026-03-29
# NOTE: For historical git blame prior to this date, see the starsend repository.
# ==============================================================================
"""
model_profiles.py — Load and query per-model capabilities from model_profiles.json.

Provides runtime access to model capabilities: supported aspect ratios,
max reference images, cost, API pattern, etc.
"""

import json
import logging
from pathlib import Path
from typing import Optional

from recoil.core.exceptions import CostMissingError
from recoil.core.paths import CONFIG_DIR

logger = logging.getLogger(__name__)

# Location of the model profiles config file read by `load()`.
_PROFILES_PATH = CONFIG_DIR / "model_profiles.json"

# Cached profiles dict; None until `load()` populates it, and reset to None
# by `reload()` before it re-reads the file.
_profiles: Optional[dict] = None
# Set to True once `_run_cross_config_validation_once()` has either run the
# cross-config check or decided to skip it; never reset within this module.
_CROSS_VALIDATED: bool = False


_METADATA_KEYS: frozenset[str] = frozenset({"schema_version"})


def iter_model_ids(profiles: dict) -> list[str]:
    """Return the model IDs of a model_profiles-shaped dict, in dict order.

    Metadata entries are dropped: the `schema_version` key and every key
    starting with `_`. Suitable for any config keyed by model_id that follows
    the single-underscore metadata convention (model_profiles.json,
    model_roles.json, prompt_bible).
    """
    model_ids: list[str] = []
    for key in profiles:
        # Skip top-level metadata; everything else is treated as a model ID.
        if key in _METADATA_KEYS or key.startswith("_"):
            continue
        model_ids.append(key)
    return model_ids


def iter_strategy_model_ids(strategy: dict) -> list[str]:
    """Return the model IDs of a provider_strategy-shaped dict, in dict order.

    Drops `schema_version` and every key starting with `__`. The
    double-underscore convention covers the legacy `__version__` key and any
    future `__internal` markers (provider_strategy.json predates the
    single-underscore convention used by model_profiles), so keys with a
    single leading underscore are kept here.
    """
    model_ids: list[str] = []
    for key in strategy:
        is_metadata = key in _METADATA_KEYS or key.startswith("__")
        if not is_metadata:
            model_ids.append(key)
    return model_ids


def load() -> dict:
    """Load model profiles from the config file (cached after first load).

    Note: `schema_version` is preserved (not stripped) because the spec
    contract requires `load().get('schema_version') == 1`. Callers that
    iterate model entries should use `iter_model_ids()` to skip metadata.

    Returns:
        The validated profiles dict. The same cached object is returned on
        every call until `reload()` clears it.

    Raises:
        Whatever `validate_and_load` or the cross-config validation raises.
        When cross-config validation fails, the cache is cleared again so the
        failure is not masked on the next call.
    """
    global _profiles
    if _profiles is None:
        from recoil.core.config_schema import validate_and_load
        # The validation step reads the module-level `_profiles`, so the cache
        # has to be assigned before it runs.
        _profiles = validate_and_load(_PROFILES_PATH, "model_profiles")
        try:
            _run_cross_config_validation_once()
        except Exception:
            # Without this reset a validation failure would raise exactly once
            # and every later load() would silently hand out the unvalidated
            # profiles from the cache.
            _profiles = None
            raise
    return _profiles


# Phase D — MF-4: promoted from private to public.
# pipeline/lib/preflight.py imported `_load` across the package boundary to
# share the cached profile dict. Promotion makes the loader part of the
# explicit contract. Underscore alias retained for one-cycle backwards
# compat — in-module callers below (and test fixtures) continue to work.
_load = load


def _run_cross_config_validation_once() -> None:
    """Run `validate_cross_config` at most once per process.

    CROSS_CONFIG_VALIDATION_DEFERRED — gated on RECOIL_ENFORCE_CROSS_CONFIG=1
    until live-config drift is resolved (see Phase 6 report).

    The `_CROSS_VALIDATED` flag is set when the check is skipped (env var not
    "1", or a sibling config file is missing) and when it completes. If the
    validation itself raises, the flag stays unset and the error propagates.
    """
    global _CROSS_VALIDATED
    if _CROSS_VALIDATED:
        return

    import os
    enforce = os.environ.get("RECOIL_ENFORCE_CROSS_CONFIG") == "1"
    if not enforce:
        _CROSS_VALIDATED = True
        return

    from recoil.core.model_profiles_validate import validate_cross_config
    from recoil.core.config_schema import validate_and_load

    # (schema name, file name) of each sibling config, in load order.
    sibling_sources = (
        ("provider_strategy", "provider_strategy.json"),
        ("pipeline_config", "pipeline_config.json"),
        ("model_roles", "model_roles.json"),
    )
    siblings = {}
    try:
        for schema_name, file_name in sibling_sources:
            siblings[schema_name] = validate_and_load(
                CONFIG_DIR / file_name, schema_name
            )
    except FileNotFoundError:
        # Tests may mock CONFIG_DIR with some files absent.
        _CROSS_VALIDATED = True
        return

    validate_cross_config(
        model_profiles=_profiles or {},
        provider_strategy=siblings["provider_strategy"],
        pipeline_config=siblings["pipeline_config"],
        model_roles=siblings["model_roles"],
    )
    _CROSS_VALIDATED = True


def get_profile(model_id: str) -> dict:
    """Get the full profile dict for a model ID.

    Args:
        model_id: Key to look up in the loaded profiles.

    Returns:
        The entry stored under `model_id` in the cached profiles dict (the
        cached object itself, not a copy).

    Raises:
        KeyError: If `model_id` is not a key of the loaded profiles. The
            message lists the known model IDs.
    """
    profiles = _load()
    if model_id not in profiles:
        # List real model IDs only: the raw keys also contain metadata such as
        # `schema_version`, which is not something a caller can select.
        available = ", ".join(iter_model_ids(profiles))
        raise KeyError(f"Unknown model: {model_id}. Available: {available}")
    return profiles[model_id]


def get_cost(model_id: str, *, allow_missing: bool = False) -> float:
    """Return the model's cost per image (or per second for video models).

    `cost_per_image` wins whenever the key is present; otherwise
    `cost_per_second` is used.

    A profile without a usable cost fails loud with ``CostMissingError``,
    because a silent ``0.0`` undercounts spend (REC-216). An explicit
    ``cost_*: 0.0`` is a legitimate billing-zero and is returned as-is; only
    an ABSENT cost raises.

    Telemetry/display callers that tolerate a missing cost pass
    ``allow_missing=True`` and get ``0.0`` plus a WARNING log instead.

    Raises:
        KeyError: If the model is unknown (from `get_profile`).
        CostMissingError: If no cost is available and `allow_missing` is False.
    """
    profile = get_profile(model_id)
    if "cost_per_image" in profile:
        cost = profile["cost_per_image"]
    else:
        cost = profile.get("cost_per_second")

    if cost is not None:
        return float(cost)

    if not allow_missing:
        raise CostMissingError(
            result_id=model_id, source=f"model_profiles.get_cost({model_id})"
        )

    logger.warning(
        "get_cost(%s): no cost_per_image/cost_per_second; "
        "returning 0.0 (allow_missing)",
        model_id,
    )
    return 0.0


def get_provider_cost_per_second(
    model_id: str,
    provider_id: Optional[str] = None,
    tier: Optional[str] = None,
) -> float:
    """Get the per-second cost for a model via its active provider and tier.

    Resolves provider and tier using the same primary/primary_tier path that
    the runtime registry uses (mirroring registry.resolve_adapter):
      - provider: ``provider_id`` arg if given, else
        ``provider_strategy[model_id]["primary"]`` (``"default"`` when the
        strategy has no such entry).
      - tier:     ``tier`` arg if given, else
        ``provider_strategy[model_id].get("primary_tier") or "default"``.

    provider_strategy.json is only read when at least one of the two has to
    be resolved; callers that pass both skip the file entirely.

    Returns the ``cost_per_second`` from
    ``profile["providers"][provider]["tiers"][tier]``.

    Fallback: if the model profile has no ``providers`` block, falls back to
    the top-level ``cost_per_second`` (legacy flat field). When that field is
    missing too, ``0.0`` is returned and a WARNING is logged so the
    zero-cost result is visible.

    Raises:
        KeyError: if the model is unknown, or if provider or tier is absent
            from the providers block.
        ValueError: if the resolved tier entry has no ``cost_per_second``.
    """
    if provider_id is None or tier is None:
        from recoil.core.config_schema import validate_and_load

        # Resolve via provider_strategy (same path as runtime).
        strategy = validate_and_load(
            CONFIG_DIR / "provider_strategy.json", "provider_strategy"
        )
        strategy_entry = strategy.get(model_id, {})
        if provider_id is None:
            provider_id = strategy_entry.get("primary", "default")
        if tier is None:
            tier = strategy_entry.get("primary_tier") or "default"

    profile = get_profile(model_id)
    providers_block = profile.get("providers")

    # Fallback: no providers block — use legacy top-level cost_per_second.
    if not providers_block:
        legacy_cost = profile.get("cost_per_second")
        if legacy_cost is None:
            logger.warning(
                "get_provider_cost_per_second(%s): no providers block and no "
                "cost_per_second; returning 0.0",
                model_id,
            )
            return 0.0
        return float(legacy_cost)

    if provider_id not in providers_block:
        raise KeyError(
            f"Model '{model_id}' providers block has no entry for provider "
            f"'{provider_id}'. Available: {list(providers_block.keys())}"
        )

    tiers = providers_block[provider_id].get("tiers", {})
    if tier not in tiers:
        raise KeyError(
            f"Model '{model_id}' provider '{provider_id}' has no tier "
            f"'{tier}'. Available: {list(tiers.keys())}"
        )

    tier_entry = tiers[tier]
    if "cost_per_second" not in tier_entry:
        raise ValueError(
            f"Model '{model_id}' provider '{provider_id}' tier '{tier}' "
            f"has no 'cost_per_second' field."
        )

    return float(tier_entry["cost_per_second"])


def get_max_refs(model_id: str) -> int:
    """Return the model's `max_reference_images`, or 0 when the profile omits it.

    Raises KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("max_reference_images", 0)


def get_aspect_ratios(model_id: str) -> list[str]:
    """Return the model's `supported_aspect_ratios`, or an empty list if unset.

    Raises KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("supported_aspect_ratios", [])


def supports_inline_refs(model_id: str) -> bool:
    """Return the profile's `supports_inline_refs` flag (Gemini pattern).

    Defaults to False when the profile does not declare the flag. Raises
    KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("supports_inline_refs", False)


def get_api_pattern(model_id: str) -> str:
    """Return the API integration pattern ('genai_inline' or 'upload_bundle').

    Profiles without an `api_pattern` field are treated as 'upload_bundle'.
    Raises KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("api_pattern", "upload_bundle")


def list_models() -> list[str]:
    """Return every available model ID, leaving out top-level metadata keys."""
    profiles = _load()
    return iter_model_ids(profiles)


def get_modality(model_id: str) -> str:
    """Return the profile's `modality` ('image' or 'video').

    Profiles without a `modality` field are treated as 'image'. Raises
    KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("modality", "image")


def get_default_duration(model_id: str) -> int:
    """Return the default video duration in seconds.

    The value is read from the profile's `max_duration_seconds` field; it is
    0 when the field is absent (e.g. image models). Raises KeyError (from
    `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("max_duration_seconds", 0)


def get_fallback_model(model_id: str) -> Optional[str]:
    """Return the profile's `fallback_model`, or None when none is defined.

    Raises KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("fallback_model")


def get_critic_override(
    model_id: str,
    critic_name: str,
    param: str = "strictness",
    default: str = "standard",
) -> str:
    """Look up a per-model critic override parameter.

    Reads `profile["critic_overrides"][critic_name][param]`, treating every
    missing level as "no override".

    Args:
        model_id: Model ID (e.g. "seedream-v4.5").
        critic_name: Critic name (e.g. "anatomy", "identity_drift").
        param: Parameter name within the override dict (default: "strictness").
        default: Value returned when no override exists.

    Returns:
        The override value, or `default` when the model is unknown or no
        override is configured for this model+critic+param combination.
    """
    try:
        profile = get_profile(model_id)
    except KeyError:
        # Unknown models simply have no overrides.
        return default
    else:
        per_critic = profile.get("critic_overrides", {}).get(critic_name, {})
        return per_critic.get(param, default)


# Strictness level thresholds -- used by critics to convert categorical
# levels to numeric thresholds. Each critic defines its own mapping.
# Levels are listed from "strict" to "relaxed" and numbered from 0 in that order.
STRICTNESS_LEVELS = {
    level: rank
    for rank, level in enumerate(("strict", "standard", "relaxed"))
}


def supports_audio(model_id: str) -> bool:
    """Return the profile's `supports_audio` flag (audio generation).

    Defaults to False when the profile does not declare the flag. Raises
    KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("supports_audio", False)


def supports_multi_shot(model_id: str) -> bool:
    """Return the profile's `supports_multi_shot` flag (multi-shot scene batching).

    Defaults to False when the profile does not declare the flag. Raises
    KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_id)
    return profile.get("supports_multi_shot", False)


def supports_start_end_frame(model_id: str) -> bool:
    """Check if the model supports start+end frame I2V control.

    True only when the profile sets both `supports_start_frame` and
    `supports_end_frame` to a truthy value; a missing flag counts as False.

    Raises:
        KeyError: If the model is unknown (from `get_profile`).
    """
    profile = get_profile(model_id)
    # `a and b` returns one of its operands, so a non-bool profile value
    # (e.g. null or 1) would leak out despite the `-> bool` annotation.
    return bool(
        profile.get("supports_start_frame", False)
        and profile.get("supports_end_frame", False)
    )


def get_coverage_mode(
    model_id: str,
    pass_type: str,
    has_start_frame: bool,
) -> str:
    """Return the generation mode a model prefers for a coverage pass.

    The mode comes from `<model>.coverage_mode_preferences` in
    model_profiles.json, where each video model declares its own mode per
    pass-type scenario. Keeping the choice in the profile keeps mode-selection
    out of coverage_planner.py and lets a model swap carry the right modes
    with it.

    Scenario key selection:
      - pass_type "env" with a start frame    -> "env_with_frame"
      - pass_type "env" without a start frame -> "env_without_frame"
      - any other pass_type                   -> "character_pass"

    Args:
        model_id: Bible model key (e.g. "seeddance-2.0").
        pass_type: "character" or "env".
        has_start_frame: True when a previs start frame exists for the pass.

    Returns:
        Mode string ("i2v", "r2v", or "t2v") declared by the model profile.

    Raises:
        KeyError: If the model is unknown, has no (or an empty)
            `coverage_mode_preferences` block, or the block lacks the
            resolved key. Add the field to model_profiles.json for any video
            model referenced in pipeline_config.json
            coverage_strategy.model_routing.
    """
    if pass_type != "env":
        key = "character_pass"
    elif has_start_frame:
        key = "env_with_frame"
    else:
        key = "env_without_frame"

    prefs = get_profile(model_id).get("coverage_mode_preferences")
    if prefs and key in prefs:
        return prefs[key]

    if prefs:
        # The block exists but does not cover this scenario.
        raise KeyError(
            f"Model '{model_id}' coverage_mode_preferences is missing key "
            f"'{key}'. Add it to model_profiles.json."
        )

    raise KeyError(
        f"Model '{model_id}' has no `coverage_mode_preferences` block in "
        f"model_profiles.json. Required for any video model referenced in "
        f"pipeline_config.json coverage_strategy.model_routing. Add: "
        f'{{"character_pass": "...", "env_with_frame": "...", '
        f'"env_without_frame": "..."}}.'
    )


# Model names for which `get_segment_duration_bounds()` has already logged the
# missing-min_duration_seconds warning; emptied again by `reload()`.
_warned_missing_min_duration: set[str] = set()


def get_segment_duration_bounds(model_name: str) -> tuple[float, float]:
    """Return `(min_duration_seconds, max_duration_seconds)` for a model.

    Both values come from the model profile and are converted to float.
    A missing (or null) minimum falls back to 4.0 and logs a WARNING the
    first time it happens for a given model; a missing maximum falls back
    to 15.0.

    Raises KeyError (from `get_profile`) for an unknown model.
    """
    profile = get_profile(model_name)
    upper = profile.get("max_duration_seconds", 15.0)
    lower = profile.get("min_duration_seconds")

    if lower is not None:
        return (float(lower), float(upper))

    # Warn only once per model so repeated lookups do not spam the log.
    already_warned = model_name in _warned_missing_min_duration
    if not already_warned:
        logger.warning(
            "Model profile for %r has no min_duration_seconds; defaulting to 4.0",
            model_name,
        )
        _warned_missing_min_duration.add(model_name)
    return (4.0, float(upper))


def reload():
    """Drop the cached profiles and read them from disk again.

    Useful after config changes. Also forgets which models already triggered
    the missing-min-duration warning, so those warnings can fire again.
    The `_CROSS_VALIDATED` flag is left untouched here.
    """
    global _profiles
    _warned_missing_min_duration.clear()
    # Clearing the cache is what makes the following load hit the disk.
    _profiles = None
    _load()


# ── Role-to-Model Mapping ─────────────────────────────────────────────

# Location of the role→model mapping file read by `_load_roles()`.
_ROLES_PATH = CONFIG_DIR / "model_roles.json"
# Cached role mapping; None until `_load_roles()` populates it, and reset to
# None by `reload_roles()` before it re-reads the file.
_roles: Optional[dict] = None


def _load_roles() -> dict:
    """Return the role→model mapping, reading the config file on first use.

    The validated result is kept in the module-level `_roles` cache and the
    same object is returned on later calls.
    """
    global _roles
    if _roles is not None:
        return _roles

    from recoil.core.config_schema import validate_and_load
    _roles = validate_and_load(_ROLES_PATH, "model_roles")
    return _roles


def get_model(
    role: str, category: Optional[str] = None, project_dir: Optional[str] = None
) -> str:
    """Look up a model ID by its role name.

    Project-level overrides are consulted first, then the global
    model_roles.json mapping.

    Args:
        role: The role key (e.g. "production", "flash", "gate_image").
        category: Optional category to search in (e.g. "image", "text", "qc").
                  If omitted, or not a known category, all categories are
                  searched for the role.
        project_dir: Optional project directory path. If provided, checks for
                     model_overrides.json in that directory first.

    Returns:
        Model ID string.

    Raises:
        KeyError: If role not found.
    """
    # Check project-level overrides first
    # TENET_6_DEFERRED_TO_PHASE_E: per-project model_overrides.json silent
    # fallback (intentional per data-contracts §override-loaders).
    if project_dir:
        override_path = Path(project_dir) / "model_overrides.json"
        try:
            overrides = json.loads(override_path.read_text(encoding="utf-8"))
        except (ValueError, OSError):
            # Missing, unreadable, badly encoded or malformed override file:
            # fall through to global roles. ValueError covers both
            # json.JSONDecodeError and UnicodeDecodeError.
            overrides = None

        # Anything other than a JSON object cannot hold category blocks and
        # is ignored rather than crashing the lookup.
        if isinstance(overrides, dict):
            if category and category in overrides:
                candidates = [overrides[category]]
            else:
                candidates = list(overrides.values())
            for cat_data in candidates:
                if isinstance(cat_data, dict) and role in cat_data:
                    return cat_data[role]

    roles = _load_roles()

    # Direct category lookup
    if category and category in roles:
        cat = roles[category]
        if isinstance(cat, dict) and role in cat:
            return cat[role]
        # A non-dict entry has no roles to list.
        known_roles = cat if isinstance(cat, dict) else ()
        raise KeyError(
            f"Role '{role}' not found in category '{category}'. "
            f"Available: {', '.join(k for k in known_roles if not k.startswith('_'))}"
        )

    # Search all categories
    for cat_name, cat_data in roles.items():
        if cat_name.startswith("_"):
            continue
        if isinstance(cat_data, dict) and role in cat_data:
            return cat_data[role]

    raise KeyError(
        f"Role '{role}' not found in any category. "
        f"Available categories: {', '.join(k for k in roles if not k.startswith('_'))}"
    )


def get_all_roles() -> dict:
    """Return the full role→model mapping dict.

    This is the cached object from `_load_roles()`, not a copy.
    """
    roles = _load_roles()
    return roles


def reload_roles() -> None:
    """Force re-read role mapping from disk.

    Clears the module-level `_roles` cache and immediately reloads it, so
    errors from reading or validating the file surface here.
    """
    global _roles
    # Resetting the cache first is what makes `_load_roles()` hit the disk.
    _roles = None
    _load_roles()


__all__ = [
    # Public symbols (Phase D — MF-3 + DEBT-9).
    # Loader: `load` is the public name (Phase D — MF-4). The underscore
    # alias `_load` stays exported until remaining callers migrate to it.
    "load",
    "_load",
    # Iteration helpers.
    "iter_model_ids",
    "iter_strategy_model_ids",
    "list_models",
    # Per-model accessors.
    "get_api_pattern",
    "get_aspect_ratios",
    "get_cost",
    "get_provider_cost_per_second",
    "get_coverage_mode",
    "get_critic_override",
    "get_default_duration",
    "get_fallback_model",
    "get_max_refs",
    "get_modality",
    "get_model",
    "get_profile",
    "get_segment_duration_bounds",
    # Capability checks.
    "supports_audio",
    "supports_inline_refs",
    "supports_multi_shot",
    "supports_start_end_frame",
    # Role mapping.
    "get_all_roles",
    "reload_roles",
    # Cache-control.
    "reload",
    # Constants.
    "STRICTNESS_LEVELS",
]


def _main() -> None:
    """Print a human-readable summary of the loaded model profiles and roles."""
    profiles = _load()
    # Count real models only; the raw dict also carries metadata keys such as
    # `schema_version`.
    model_ids = iter_model_ids(profiles)
    print(f"Model profiles loaded: {len(model_ids)} models\n")
    for model_id in model_ids:
        profile = profiles[model_id]
        cost = profile.get("cost_per_image", profile.get("cost_per_second", "N/A"))
        print(f"  {model_id}")
        print(f"    Display: {profile.get('display_name', 'N/A')}")
        print(f"    Provider: {profile.get('provider', 'N/A')}")
        print(f"    Cost: ${cost}")
        print(f"    Max refs: {profile.get('max_reference_images', 'N/A')}")
        print(f"    Aspects: {', '.join(profile.get('supported_aspect_ratios', []))}")
        print(f"    API: {profile.get('api_pattern', 'N/A')}")
        print()

    # Print role mappings. Only dict-valued, non-underscore categories are
    # listed, and the headline count covers exactly those categories.
    roles = _load_roles()
    categories = {
        cat_name: cat_data
        for cat_name, cat_data in roles.items()
        if not cat_name.startswith("_") and isinstance(cat_data, dict)
    }
    role_count = sum(len(cat_data) for cat_data in categories.values())
    print(f"\nModel roles loaded: {role_count} roles\n")
    for cat_name, cat_data in categories.items():
        print(f"  [{cat_name}]")
        for role, model_id in cat_data.items():
            print(f"    {role}: {model_id}")
        print()


if __name__ == "__main__":
    _main()
