"""Canonical cost computation + reading helpers.

Phase C of the engine-architectural-audit fix sprint (2026-04-30):
- Collapses 4 compute_cost signatures across modalities into ONE typed
  dispatcher that reads model_profiles.json (via the Phase A canonical
  loader core.model_profiles.get_profile) and routes to the right cost
  formula based on the profile's billing modality.
- Provides ONE canonical reader (read_cost_from_result) that raises
  CostMissingError when cost is absent — replaces the 27 production
  `result.metadata.get("cost_usd", 0.0) or 0.0` sites.
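- Provides sanctioned-fallback readers (read_cost_from_result_safe, and
  read_cost_from_record_safe for plain dict records) for display and
  aggregation paths. A missing cost is tolerated there: the reader returns
  0.0 and emits a FALLBACK_FIRED log line.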

Tenet 6 (Errors Must Be Visible) — codified 2026-04-30: a missing cost
is NOT silently zero. Either it's a known billing-zero case (no API call
made — explicit cost_usd=0.0 stored) or it's a violation (we billed but
forgot to record). The default reader raises; the sanctioned variant
emits a FALLBACK_FIRED log line every time it fires.
"""

from __future__ import annotations

from typing import Any, Optional

from recoil.core.exceptions import CostMissingError  # noqa: F401  # DEPRECATED: Phase E.5 migration
from recoil.pipeline._lib.sanctioned_fallbacks import fire_sanctioned_fallback


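# ── Tenet 6: cost-missing exception ─────────────────────────────────────
# CostMissingError is imported above from recoil.core.exceptions and is
# re-exported here for one-cycle backward compatibility (Phase E.5).
# NOTE(review): the canonical home was documented as
# recoil/lib/exceptions.py::CostMissingError, but the import above uses
# recoil.core.exceptions. Confirm which module is canonical.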


# ── Canonical reader ────────────────────────────────────────────────────

def read_cost_from_result(result: Any) -> float:
    """Return the recorded USD cost of a RunResult-like object.

    Lookup order:
      1. ``result.metadata["cost_usd"]``, when ``metadata`` is a dict that
         contains the key (the canonical home per CP-4).
      2. ``result.cost_usd`` (legacy StepResult attribute).

    The first source that exists wins. If it holds ``None``, that counts as
    missing, and the reader does NOT fall through to the next source.

    Raises:
        CostMissingError: the chosen source holds ``None``, or neither
            source exists on ``result``.
    """
    metadata = getattr(result, "metadata", None)

    if isinstance(metadata, dict) and "cost_usd" in metadata:
        value, missing_reason = metadata["cost_usd"], "metadata.cost_usd is None"
    elif hasattr(result, "cost_usd"):
        value, missing_reason = result.cost_usd, "attribute.cost_usd is None"
    else:
        raise CostMissingError(
            result_id=getattr(result, "id", None),
            source="runresult-no-cost-anywhere",
        )

    if value is None:
        raise CostMissingError(
            result_id=getattr(result, "id", None),
            source=missing_reason,
        )
    return float(value)


def read_cost_from_result_safe(result: Any) -> float:
    """Sanctioned-fallback reader for display/aggregation paths.

    Delegates to read_cost_from_result. It returns 0.0 and fires the
    "cost_unknown_telemetry_zero" sanctioned fallback when:
      - the cost is missing (CostMissingError), or
      - the cost is present but not numeric. In that case float() inside
        the strict reader raises TypeError/ValueError. This mirrors the
        "record-cost-not-numeric" handling in read_cost_from_record_safe,
        so a malformed value cannot crash a tolerant caller.

    Returns:
        The recorded cost as a float, or 0.0 on either fallback path.
    """
    try:
        return read_cost_from_result(result)
    except CostMissingError as e:
        # Listed first so that CostMissingError is always reported with its
        # own source/result_id, even if it subclasses ValueError.
        fire_sanctioned_fallback(
            "cost_unknown_telemetry_zero",
            source=e.source,
            result_id=e.result_id,
        )
        return 0.0
    except (TypeError, ValueError):
        fire_sanctioned_fallback(
            "cost_unknown_telemetry_zero",
            source="result-cost-not-numeric",
            result_id=getattr(result, "id", None),
        )
        return 0.0


def read_cost_from_record_safe(record: dict) -> float:
    """Tolerant cost reader for plain dict records.

    Used for PassStore, ExecutionStore and pipeline-orchestrator records.
    Prefers the canonical ``cost_usd`` key. When that is absent or None,
    it falls back to the legacy orchestrator ``cost`` key.

    Every failure fires the "cost_unknown_telemetry_zero" sanctioned
    fallback (FALLBACK_FIRED) and returns 0.0. The failure cases are:
    ``record`` is not a dict, neither key holds a value, or the value
    cannot be converted with float().
    """
    if not isinstance(record, dict):
        fire_sanctioned_fallback(
            "cost_unknown_telemetry_zero",
            source="record-not-dict",
            record_type=type(record).__name__,
        )
        return 0.0

    raw = record.get("cost_usd")
    # Orchestrator dicts store the cost under the legacy "cost" key.
    raw = record.get("cost") if raw is None else raw

    if raw is not None:
        try:
            return float(raw)
        except (TypeError, ValueError):
            fire_sanctioned_fallback(
                "cost_unknown_telemetry_zero",
                source="record-cost-not-numeric",
                value=repr(raw)[:40],
            )
            return 0.0

    fire_sanctioned_fallback(
        "cost_unknown_telemetry_zero",
        source="record-missing-cost_usd",
        record_keys=list(record.keys())[:6],
    )
    return 0.0


# ── Unified cost reader (Phase E-facing wrapper) ────────────────────────

def get_cost(result: Any, *, allow_missing: bool = False) -> float:
    """Read cost_usd from ``result`` through one entry point.

    allow_missing=False (default): strict. Delegates to
        read_cost_from_result, which raises CostMissingError on a
        missing cost.
    allow_missing=True: tolerant. Delegates to read_cost_from_result_safe,
        which fires the sanctioned fallback and returns 0.0 on a
        missing cost.
    """
    reader = read_cost_from_result_safe if allow_missing else read_cost_from_result
    return reader(result)


# ── Canonical compute dispatcher ────────────────────────────────────────

def compute_cost(
    model_id: str,
    *,
    duration_s: Optional[float] = None,
    char_count: Optional[int] = None,
    token_input_count: Optional[int] = None,
    token_output_count: Optional[int] = None,
    tier: Optional[str] = None,
    profile: Optional[dict] = None,
) -> float:
    """Canonical cost computation. One typed signature, modality-dispatching.

    The profile's ``billing_unit`` selects the formula:

    - billing_unit="per_second" → rate * duration_s (video)
    - billing_unit="per_1k_chars" → rate * (char_count / 1000.0) (audio)
    - billing_unit="per_1k_tokens" → input_rate*input/1k + output_rate*output/1k (eval)
    - billing_unit="flat_per_image" → flat_rate (image)

    Args:
        model_id: Model identifier. Used to look up the profile when
            ``profile`` is not supplied, and in error messages.
        duration_s, char_count, token_input_count, token_output_count:
            The billed quantity for the profile's billing unit. The ones
            that unit needs must be provided and must be >= 0.
        tier: Pricing tier. Required when the profile's rate is tier-keyed.
        profile: Pre-loaded profile dict. Skips the get_profile lookup.

    Returns:
        The cost in USD as a float.

    Raises:
        CostMissingError: no profile is found, the profile lacks the needed
            rate, or ``billing_unit`` is not recognised.
        ValueError: the quantity the billing unit needs is missing,
            negative or NaN, or tier resolution fails.
    """
    if profile is None:
        # Deferred import. NOTE(review): presumably this avoids an import
        # cycle with recoil.core; confirm.
        from recoil.core.model_profiles import get_profile
        profile = get_profile(model_id)

    if profile is None:
        raise CostMissingError(
            result_id=model_id,
            source=f"compute_cost: no profile for model_id={model_id!r}",
        )

    # NOTE(review): a profile without an explicit billing_unit is billed per
    # second. Confirm that every non-video profile sets billing_unit.
    billing_unit = profile.get("billing_unit", "per_second")

    # Caller-supplied quantities are validated before rate lookup. A bad
    # input (negative, NaN) must never turn into a negative/NaN cost that
    # silently skews aggregated spend.
    if billing_unit == "per_second":
        if duration_s is None:
            raise ValueError(
                f"compute_cost: duration_s required for {model_id} "
                f"(billing_unit=per_second)"
            )
        seconds = _non_negative_quantity(model_id, "duration_s", duration_s)
        return _resolve_rate(profile, "cost_per_second", tier) * seconds

    if billing_unit == "per_1k_chars":
        if char_count is None:
            raise ValueError(
                f"compute_cost: char_count required for {model_id} "
                f"(billing_unit=per_1k_chars)"
            )
        chars = _non_negative_quantity(model_id, "char_count", char_count)
        return _resolve_rate(profile, "cost_per_1k_chars", tier) * (chars / 1000.0)

    if billing_unit == "per_1k_tokens":
        if token_input_count is None or token_output_count is None:
            raise ValueError(
                f"compute_cost: token_input_count + token_output_count "
                f"required for {model_id} (billing_unit=per_1k_tokens)"
            )
        tokens_in = _non_negative_quantity(
            model_id, "token_input_count", token_input_count
        )
        tokens_out = _non_negative_quantity(
            model_id, "token_output_count", token_output_count
        )
        in_rate = _resolve_rate(profile, "cost_per_1k_tokens_input", tier)
        out_rate = _resolve_rate(profile, "cost_per_1k_tokens_output", tier)
        return in_rate * (tokens_in / 1000.0) + out_rate * (tokens_out / 1000.0)

    if billing_unit == "flat_per_image":
        return _resolve_rate(profile, "image_cost", tier)

    raise CostMissingError(
        result_id=model_id,
        source=f"compute_cost: unknown billing_unit={billing_unit!r}",
    )


def _non_negative_quantity(model_id: str, name: str, value: float) -> float:
    """Return ``float(value)``, raising ValueError if it is negative or NaN."""
    quantity = float(value)
    # `not quantity >= 0.0` (rather than `quantity < 0`) also rejects NaN,
    # because every comparison with NaN is False.
    if not quantity >= 0.0:
        raise ValueError(
            f"compute_cost: {name} must be a non-negative number for "
            f"{model_id}, got {value!r}"
        )
    return quantity


def _resolve_rate(profile: dict, key: str, tier: Optional[str]) -> float:
    """Resolve a rate from profile, supporting tier-keyed overrides."""
    rate = profile.get(key)
    if rate is None:
        raise CostMissingError(
            result_id=profile.get("model_id", "?"),
            source=f"_resolve_rate: profile missing {key!r}",
        )
    if isinstance(rate, dict):
        if not rate:
            raise CostMissingError(
                result_id=profile.get("model_id", "?"),
                source=f"_resolve_rate: tier-keyed {key!r} is empty",
            )
        if not tier:
            raise ValueError(
                f"_resolve_rate: tier argument required for "
                f"{profile.get('model_id', '?')!r} (tier-keyed pricing on "
                f"{key!r}). Available tiers: {list(rate.keys())}"
            )
        if tier not in rate:
            raise ValueError(
                f"_resolve_rate: tier {tier!r} not found in pricing for "
                f"{profile.get('model_id', '?')!r} on {key!r}. "
                f"Available tiers: {list(rate.keys())}"
            )
        return float(rate[tier])
    return float(rate)


# Public API of this module. CostMissingError is imported from
# recoil.core.exceptions and listed here so that it stays importable from
# this module during the one-cycle Phase E.5 backward-compatibility window.
__all__ = [
    "CostMissingError",
    "read_cost_from_result",
    "read_cost_from_result_safe",
    "read_cost_from_record_safe",
    "get_cost",
    "compute_cost",
]
