"""BudgetGuard -- Thread-safe budget enforcement for run.shot / run.episode.

Pre-check (would_exceed) creates a tentative reservation.
Charge commits a previously reserved amount.
Release frees an unrealized reservation (API call failed with no charge).

Thread safety: all mutations under threading.RLock (not asyncio.Lock --
StepRunner is synchronous, threads need OS-level locking).  RLock is used
instead of Lock so that at_warn_threshold can be safely called from within
charge() without deadlock.
"""

from __future__ import annotations

import logging
import threading

logger = logging.getLogger(__name__)


class BudgetExceeded(Exception):
    """Error signalling that a budget limit has been hit.

    Attributes:
        spent: Spend figure supplied by the raiser (0.0 when omitted).
        limit: Limit figure supplied by the raiser (0.0 when omitted).
    """

    def __init__(
        self,
        message: str,
        spent: float = 0.0,
        limit: float = 0.0,
    ):
        # Keep the numeric context on the instance; only the message is
        # forwarded to Exception, so str(exc) and exc.args stay message-only.
        self.spent, self.limit = spent, limit
        super().__init__(message)


class ProviderError(Exception):
    """Provider-side failure that may still have cost money.

    Phase 5F (failed-but-billed accounting): fal bills a request that ends
    in content_policy_violation / partner_validation_failed even though no
    usable output is returned. Provider adapters (or detection shims) may
    raise this typed exception with `inference_billed_usd > 0` so that the
    BudgetGuard can record that spend rather than letting its tally drift.

    Attributes:
        inference_billed_usd: Billed amount for the failed attempt, always
            stored as a float (0.0 when nothing was billed / not supplied).
        kind: Optional classification string for the failure, or None.
    """

    def __init__(
        self,
        message: str = "",
        *,
        inference_billed_usd: float = 0.0,
        kind: str | None = None,
    ):
        # Normalise the billed amount (callers may pass an int or a numeric
        # string) before attaching it to the instance.
        billed = float(inference_billed_usd)
        super().__init__(message)
        self.kind = kind
        self.inference_billed_usd = billed


class BudgetGuard:
    """Thread-safe budget enforcement with reservation semantics.

    Typical lifecycle of one API call:

    1. ``would_exceed(estimate)`` -- returns False and reserves ``estimate``
       when it fits in the remaining budget; returns True (and reserves
       nothing) otherwise.
    2. ``charge(actual, reserved_amount=estimate)`` records the real spend
       and retires the reservation, or ``release(estimate)`` retires the
       reservation without recording any spend.

    ``_spent``, ``_reserved`` and ``events`` are only read or written while
    holding ``_lock``. The lock is a re-entrant ``threading.RLock`` so that a
    method already holding it (``charge``) can use the lock-taking
    properties without deadlocking.
    """

    # Fraction of the limit at or above which charge() logs a warning.
    _WARN_FRACTION = 0.8

    def __init__(
        self,
        limit_usd: float,
        label: str = "",
        per_shot_cap_usd: float | None = None,
    ):
        """Create a guard.

        Args:
            limit_usd: Total budget; spent + reserved may not exceed it.
            label: Name used in warning log lines.
            per_shot_cap_usd: Optional ceiling for a single estimate, checked
                by would_exceed_per_shot(). None disables that check.
        """
        self._limit = limit_usd
        self._label = label
        self._per_shot_cap = per_shot_cap_usd
        self._spent = 0.0
        self._reserved = 0.0
        self._lock = threading.RLock()
        # Phase 5F telemetry: every charge() is appended here so callers can
        # introspect succeeded vs failed_but_billed events post-hoc.
        self.events: list[dict] = []

    def would_exceed(self, estimated_cost: float) -> bool:
        """Pre-check with tentative reservation.

        Returns True if spent + reserved + estimated_cost would go over the
        limit; nothing is reserved in that case.
        If False (within budget), a reservation is created for estimated_cost.
        The caller MUST follow up with either charge() or release().
        """
        with self._lock:
            if self._spent + self._reserved + estimated_cost > self._limit:
                return True
            self._reserved += estimated_cost
            return False

    def would_exceed_per_shot(self, estimated_cost: float) -> bool:
        """Check against the per-shot cap (model-level ceiling).

        Returns False when no cap was configured. Does not reserve anything.
        """
        if self._per_shot_cap is None:
            return False
        return estimated_cost > self._per_shot_cap

    def charge(
        self,
        actual_cost: float,
        reserved_amount: float | None = None,
        *,
        kind: str = "succeeded",
    ) -> None:
        """Commit a previously reserved amount.

        Args:
            actual_cost: The real cost incurred by the API call.
            reserved_amount: The original estimated_cost from would_exceed().
                If provided, the reservation is debited by this amount.
                If None, falls back to debiting actual_cost from reservations.
            kind: Charge classification for telemetry. "succeeded" for normal
                completions; "failed_but_billed" when fal billed an attempt
                that did not produce a usable artifact (content_policy_violation
                etc.). Recorded in self.events alongside running_total.

        The charge is always recorded, even if it pushes spent past the
        limit; enforcement happens in would_exceed(). A warning is logged
        once the warn threshold is reached.
        """
        with self._lock:
            self._spent += actual_cost
            debit = actual_cost if reserved_amount is None else reserved_amount
            # Over-charge or mismatched reservation consumes whatever is
            # reserved; the pool never goes negative (same clamp as release()).
            self._reserved = max(0.0, self._reserved - debit)
            self.events.append({
                "amount": float(actual_cost),
                "kind": kind,
                "running_total": float(self._spent),
            })
            if self.at_warn_threshold:
                # A zero (or negative) limit is always at threshold; report it
                # as fully consumed instead of dividing by zero, which would
                # raise after the charge has already been recorded.
                if self._limit > 0:
                    pct_spent = (self._spent / self._limit) * 100
                else:
                    pct_spent = 100.0
                logger.warning(
                    "Budget %s: %.1f%% spent ($%.2f / $%.2f)",
                    self._label, pct_spent,
                    self._spent, self._limit,
                )

    def release(self, reserved_amount: float) -> None:
        """Release unrealized reservation (API call failed, no charge).

        The reservation pool is clamped at zero if reserved_amount is larger
        than what is currently reserved.
        """
        with self._lock:
            self._reserved = max(0.0, self._reserved - reserved_amount)

    @property
    def spent(self) -> float:
        """Total amount committed via charge()."""
        with self._lock:
            return self._spent

    @property
    def reserved(self) -> float:
        """Amount currently held by outstanding reservations."""
        with self._lock:
            return self._reserved

    @property
    def remaining(self) -> float:
        """Limit minus spent minus reserved (negative if over-charged)."""
        with self._lock:
            return self._limit - self._spent - self._reserved

    @property
    def at_warn_threshold(self) -> bool:
        """True when 80%+ of budget consumed."""
        # Takes the lock like the other read properties. charge() calls this
        # while already holding it, which is safe because _lock is an RLock.
        with self._lock:
            return self._spent >= self._limit * self._WARN_FRACTION
