"""Lightweight cost-rate helper for Phase 5F pre-flight budget gating.

The episode runner's pre-flight check needs ONE number — the per-second rate
for a model — so it can compute an estimated charge BEFORE dispatching.
The canonical `compute_cost` in `recoil/pipeline/core/cost.py` already
handles full billing-unit dispatch, but pulling that in requires a real
`RunResult` shape. For pre-flight we only need the per-second rate.

This module reads `recoil/config/model_profiles.json` once at import time
(profiles are stable across a run) and exposes `cost_per_second(model_id)`.

If `cost_per_second` is tier-keyed in the profile (dict of tier -> rate),
this helper returns the MAX rate across tiers so pre-flight stays
conservative (we'd rather over-estimate and gate early than under-estimate
and overshoot the budget — the very bug Phase 5F fixes).
"""

from __future__ import annotations

import json
import math
from pathlib import Path

# recoil/pipeline/_lib/cost.py  ->  parents[2] = recoil/  ->  recoil/config/...
_PROFILES_PATH = (
    Path(__file__).resolve().parents[2] / "config" / "model_profiles.json"
)


def _load_profiles(path: Path) -> dict:
    """Best-effort load of the model-profiles JSON file at `path`.

    Returns the parsed top-level object when it is a JSON object (dict).
    Returns ``{}`` instead of raising when the file is missing or unreadable,
    is not valid UTF-8, is not valid JSON, or has a non-object top level.
    This keeps a bad profiles file from breaking the import of this module.
    """
    try:
        # Explicit UTF-8: JSON text is UTF-8 (RFC 8259), while read_text()'s
        # default is the platform locale encoding (e.g. cp1252 on Windows).
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers FileNotFoundError / permission errors; ValueError
        # covers both json.JSONDecodeError and UnicodeDecodeError.
        return {}
    # A top-level list/str/number would make later `.get()` lookups raise
    # AttributeError; treat it the same as "no profiles available".
    return data if isinstance(data, dict) else {}


# Read once at import time; profiles are treated as stable across a run.
_PROFILES: dict = _load_profiles(_PROFILES_PATH)


# Legacy per-variant keys (seedance fast/720p variants), consulted only when
# `cost_per_second` itself yields no usable rate.
_LEGACY_RATE_KEYS = ("cost_per_second_fast", "cost_per_second_fast_720p")


def _is_rate(value: object) -> bool:
    """Return True if `value` can be used as a numeric per-second rate."""
    # bool subclasses int, so without the explicit exclusion a JSON `true`
    # would be read as 1.0 USD/s. NaN is rejected because it poisons max()
    # and makes every later budget comparison False (i.e. never gates).
    return (
        isinstance(value, (int, float))
        and not isinstance(value, bool)
        and not math.isnan(value)
    )


def cost_per_second(model_id: str, *, profiles: dict | None = None) -> float:
    """Return the per-second USD cost for `model_id`.

    Resolution order within the model's profile:

    1. ``cost_per_second`` as a plain number -> that number.
    2. ``cost_per_second`` as a tier-keyed dict (e.g.
       ``{"standard": 0.30, "fast": 0.24}``) -> the MAX numeric tier rate,
       so pre-flight gates on the worst case.
    3. Legacy keys ``cost_per_second_fast`` / ``cost_per_second_fast_720p``
       -> the MAX of whichever are present (same conservative policy).

    Booleans and NaN are never accepted as rates.

    Args:
        model_id: Key into the profiles table.
        profiles: Optional table to consult instead of the module-level one
            loaded from ``model_profiles.json`` (e.g. for tests or an
            alternate profile set). ``None`` means "use the module table".

    Returns:
        The rate as a float, or 0.0 if the model is unknown, its profile
        entry is not a dict, or no per-second rate is present (image /
        per-1k-chars models). Pre-flight callers can treat 0.0 as "no
        per-second gate possible — fall through to legacy behavior".
        Malformed profile data never raises.
    """
    table = _PROFILES if profiles is None else profiles
    profile = table.get(model_id) if isinstance(table, dict) else None
    if not isinstance(profile, dict):
        # Unknown model, or a malformed (non-object) profile entry.
        return 0.0

    rate = profile.get("cost_per_second")
    if _is_rate(rate):
        return float(rate)
    if isinstance(rate, dict):
        # Tier-keyed. Be conservative: return the max so pre-flight gates on
        # the worst case.
        tier_rates = [v for v in rate.values() if _is_rate(v)]
        if tier_rates:
            return float(max(tier_rates))
        # No usable tier values: fall through to the legacy keys rather than
        # reporting 0.0 ("no gate"), which is the under-estimating direction.

    # Last-ditch legacy keys. Take the max of whichever are present so this
    # path is as conservative as the tier-keyed one.
    candidates = (profile.get(key) for key in _LEGACY_RATE_KEYS)
    legacy_rates = [v for v in candidates if _is_rate(v)]
    return float(max(legacy_rates)) if legacy_rates else 0.0


# Public API: only the rate lookup is exported; the underscore-prefixed
# profile path/table stay module-private.
__all__ = ["cost_per_second"]
