"""Gemini Vision provider adapter (CP-9 Phase 2).

Single synchronous HTTP call to Gemini 3.1 Pro Preview's generateContent
endpoint. Used by EvalNode adapters (CP-9 Phase 3+) to score image / video /
audio artifacts. Output is a JSON-parsed score+reasoning verdict. Retries on
429/500/503/504 + network blips with exponential backoff (with multiplicative
jitter + optional Retry-After honoring); fail-fast on 400/401/402/403/404 and
on any other status that has no retry budget in RETRY_COUNTS.

Public surface:
    score_artifact(...) -> EvalProviderResult
    EvalProviderResult dataclass
    EvalProviderError + subclasses (EvalAuthError, EvalQuotaError,
        EvalPayloadError, EvalRateLimitError, EvalServerError, EvalNetworkError)

Test injection: pass `transport=<callable>`. Default transport is
urllib.request.urlopen wrapped to enforce headers/timeout.

Cost computation: per-1k input + output token rates from
recoil/config/model_profiles.json[<model_id>]. Long-context bucket switches
at promptTokenCount > long_context_threshold_tokens (200_000).

Multimodal: inlineData (base64) for files whose base64-encoded request
payload is <20 MB. Files API path is sketched (build_part returns a
{"fileData": {...}} dict) but uploads to the Files API itself live behind a
hard-coded "deferred" raise — Phase 2 only takes the inline path; the
larger-than-inline path raises EvalPayloadError pointing at Files API as the
follow-up. Phase 2 tests cover both branches: inline success + the size
ceiling raise.
"""

from __future__ import annotations

import base64
import json
import logging
import mimetypes
import os
import random
import socket
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Literal, Optional

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

DEFAULT_MODEL_ID = "gemini-3.1-pro-preview"
"""Live API model id (NOT the bare 'gemini-3.1-pro' marketing alias which
404s at the generateContent URL). § 12a item 1."""

# Env vars consulted for the API key. The fallback is only read when the
# caller asks for the primary name (see _resolve_api_key).
PRIMARY_AUTH_ENV_VAR = "GEMINI_API_KEY"
FALLBACK_AUTH_ENV_VAR = "GOOGLE_API_KEY"
DEFAULT_AUTH_ENV_VAR = PRIMARY_AUTH_ENV_VAR  # back-compat alias

# Per-call timeout (seconds) handed to the transport.
DEFAULT_TIMEOUT_S = 120.0
# `{model}` is filled in with the model id by score_artifact.
ENDPOINT_TEMPLATE = (
    "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
)

# 20 MiB (20 * 1024 * 1024) ceiling on the BASE64-encoded artifact payload.
# Base64 inflates ~33%, so raw_bytes > 14_500_000 is the practical fallback
# trigger. § 12a item 4.
INLINE_DATA_MAX_BYTES = 20 * 1024 * 1024
INLINE_DATA_RAW_FALLBACK_THRESHOLD = 14_500_000

# Long-context bucket switch (Gemini 3.x pricing). § 12a item 3.
# Used only when the model profile does not supply its own threshold.
LONG_CONTEXT_THRESHOLD_TOKENS_DEFAULT = 200_000

# Retry counts (vendor doc § 8 / audit § 12a item 7).
# Maps HTTP status -> number of retries. Statuses absent from this map get
# zero retries (see _max_retries_for). Trailing comments show the pre-jitter
# sleep before each retry.
RETRY_COUNTS = {
    429: 3,   # 1s, 2s, 4s
    500: 3,   # 1s, 2s, 4s
    503: 3,   # 5s, 10s, 20s (longer schedule)
    504: 2,   # 2s, 4s (NOT 3x)
}
NETWORK_RETRY_COUNT = 3  # URLError / socket.timeout: 1s, 2s, 4s

# Backoff base schedules (seconds before jitter). Entry i is the base sleep
# before retry i+1; selected per status code by _schedule_for.
RETRY_SCHEDULE_DEFAULT = [1.0, 2.0, 4.0]
RETRY_SCHEDULE_503 = [5.0, 10.0, 20.0]
RETRY_SCHEDULE_504 = [2.0, 4.0]

# Retry-After header cap (seconds). § 12a item 9.
RETRY_AFTER_CAP_S = 30.0

# Supported artifact modalities — single source of truth.
SUPPORTED_MODALITIES: tuple[str, ...] = ("image", "video", "audio")
ArtifactModality = Literal["image", "video", "audio"]

# Mime fallbacks per modality (keyed by SUPPORTED_MODALITIES). Used only when
# mimetypes cannot guess a type from the artifact's file name.
_MIME_FALLBACK = {
    "image": "image/png",
    "video": "video/mp4",
    "audio": "audio/mp3",
}

# Predefined warning tokens for EvalProviderResult.raw_metadata["warnings"].
# These are the WIRE CONTRACT — Phase 3+ callers may switch on these strings.
WARNING_SCORE_CLIPPED = "score_clipped"
WARNING_TRUNCATED_MAX_TOKENS = "truncated_max_tokens"
WARNING_MISSING_TOKEN_COUNT = "missing_token_count"

# EvalPayloadError reason marker for MAX_TOKENS + unparseable inner text.
# It is the prefix of the exception message raised by _parse_response.
TRUNCATED_UNPARSEABLE_MARKER = "truncated_unparseable"


# ---------------------------------------------------------------------------
# Exception tree
# ---------------------------------------------------------------------------

class EvalProviderError(RuntimeError):
    """Base class for Gemini Vision adapter failures.

    Also raised directly for HTTP statuses that _classify_http_error does
    not map to a more specific subclass.
    """


class EvalAuthError(EvalProviderError):
    """401/403 — invalid API key or permission denied. Fail-fast (never
    retried). Also raised by score_artifact when no API key is found in the
    environment."""


class EvalQuotaError(EvalProviderError):
    """402 — out of credits. Reserved (Gemini does not currently emit 402;
    quota issues come back as 429 → EvalRateLimitError). Kept in the tree
    for Vertex AI fallback or future API change. Never retried."""


class EvalPayloadError(EvalProviderError):
    """400/404 — bad request body, bad model id, response JSON unparseable,
    unsupported modality, artifact missing, or artifact > inline-data
    ceiling without Files API support yet. Never retried."""


class EvalRateLimitError(EvalProviderError):
    """429 — rate-limited. Retried per RETRY_COUNTS[429]; raised on
    exhaustion."""


class EvalServerError(EvalProviderError):
    """5xx — retried per RETRY_COUNTS (statuses with no entry there are
    raised immediately); raised when retries exhausted. Also used when the
    response's finishReason is SAFETY or anything other than STOP /
    MAX_TOKENS."""


class EvalNetworkError(EvalProviderError):
    """URLError / socket.timeout / TimeoutError / ConnectionError. Retried
    NETWORK_RETRY_COUNT times. After exhaustion, raised chained (``from``)
    to the original exception."""


# ---------------------------------------------------------------------------
# Result dataclass — § 12a item 14
# ---------------------------------------------------------------------------

@dataclass
class EvalProviderResult:
    """Adapter-level result. Higher-level eval runners (Phase 4) wrap this
    in RunResult / GenerationReceipt.

    raw_metadata holds adapter + Gemini-passthrough keys:
      - ``warnings``: list[str] — wire-contract tokens (WARNING_SCORE_CLIPPED,
        WARNING_TRUNCATED_MAX_TOKENS, WARNING_MISSING_TOKEN_COUNT).
      - ``judge_id``: echoed caller-supplied id.
      - ``artifact_modality``: "image" | "video" | "audio".
      - ``score_clipped``: dict {original: float} — present only when the
        model returned a score outside [0.0, 1.0].
      - ``finish_reason``: raw Gemini finishReason enum string ("STOP",
        "MAX_TOKENS", etc.) — Gemini-shape passthrough.
      - ``usage``: Gemini usageMetadata dict (camelCase keys:
        ``promptTokenCount`` / ``candidatesTokenCount`` / ``totalTokenCount``)
        — Gemini-shape passthrough; do NOT treat as provider-neutral.
    Phase 3+ callers should normalize provider-shape keys before use.
    """
    # Verdict score; _parse_response clips out-of-range values into
    # [0.0, 1.0] and records the original under raw_metadata["score_clipped"].
    score: float
    # The verdict's "reasoning" value, coerced to str.
    reasoning: str
    # Best-effort estimate from _compute_cost; 0.0 when the model profile or
    # its rates cannot be read.
    cost_usd: float
    # The model id the request was sent to (echo of the caller's model_id).
    model_used: str
    # Request id taken from the response headers, or None when absent.
    request_id: Optional[str]
    # See the key list in the class docstring.
    raw_metadata: dict[str, Any] = field(default_factory=dict)


# ---------------------------------------------------------------------------
# Default transport
# ---------------------------------------------------------------------------

def _default_transport(url: str, *, headers: dict, body: bytes, timeout: float):
    """POST ``body`` to ``url`` via urllib and return the open response.

    This is the production transport used when ``score_artifact`` is not
    handed one. Tests substitute any callable with the same signature.
    HTTP error statuses surface as ``urllib.error.HTTPError`` from
    ``urlopen``; they are not translated here.
    """
    request = urllib.request.Request(
        url=url,
        data=body,
        headers=headers,
        method="POST",
    )
    return urllib.request.urlopen(request, timeout=timeout)


# ---------------------------------------------------------------------------
# HTTP error → typed exception
# ---------------------------------------------------------------------------

def _classify_http_error(code: int, body_bytes: bytes) -> EvalProviderError:
    """Map an HTTP status code and error body to a typed adapter exception.

    The exception is *returned*, not raised, so the caller can first decide
    whether the failure is retryable.

    Args:
        code: HTTP status code of the failed response.
        body_bytes: Raw error body. Decoded as UTF-8 (undecodable bytes are
            replaced) and truncated to 512 characters for the message.

    Returns:
        EvalAuthError for 401/403, EvalQuotaError for 402, EvalPayloadError
        for 400/404, EvalRateLimitError for 429, EvalServerError for any
        5xx status, and a plain EvalProviderError for everything else.
    """
    msg = body_bytes.decode("utf-8", errors="replace")[:512]
    if code in (401, 403):
        return EvalAuthError(f"Gemini {code} (auth): {msg}")
    if code == 402:
        return EvalQuotaError(f"Gemini 402 (quota): {msg}")
    if code == 400:
        return EvalPayloadError(f"Gemini 400 (payload): {msg}")
    if code == 404:
        return EvalPayloadError(f"Gemini 404 (not found / bad model id): {msg}")
    if code == 429:
        return EvalRateLimitError(f"Gemini 429 (rate-limited): {msg}")
    # The whole 5xx class is a server-side failure, not just 500-504;
    # previously e.g. 505/507/529 fell through to the untyped base error.
    if 500 <= code <= 599:
        return EvalServerError(f"Gemini {code}: {msg}")
    return EvalProviderError(f"Gemini {code}: {msg}")


# ---------------------------------------------------------------------------
# Cost compute (§ 12a item 3)
# ---------------------------------------------------------------------------

def _compute_cost(
    model_id: str,
    input_tokens: int,
    output_tokens: int,
    *,
    long_context: bool = False,
) -> float:
    """Estimate the USD cost of one call from the model profile's rates.

    Rates are per 1k tokens and are looked up lazily through
    ``recoil.core.model_profiles.get_profile``. ``long_context`` selects the
    ``*_long_context`` rate keys instead of the standard ones. A missing or
    falsy rate counts as 0.0.

    This is deliberately best-effort: any exception (profile module not
    importable, unknown model, non-numeric rate, ...) is logged at debug
    level and the function returns 0.0 rather than failing the scoring call.
    """
    try:
        from recoil.core.model_profiles import get_profile

        profile = get_profile(model_id)
        if long_context:
            input_key = "cost_per_1k_input_tokens_long_context"
            output_key = "cost_per_1k_output_tokens_long_context"
        else:
            input_key = "cost_per_1k_input_tokens"
            output_key = "cost_per_1k_output_tokens"

        input_rate = float(profile.get(input_key) or 0.0)
        output_rate = float(profile.get(output_key) or 0.0)

        input_cost = input_rate * input_tokens / 1000.0
        output_cost = output_rate * output_tokens / 1000.0
        return input_cost + output_cost
    except Exception as e:  # noqa: BLE001 — best-effort
        logger.debug(f"cost compute fell back to 0.0 ({e})")
        return 0.0


def _long_context_threshold(model_id: str) -> int:
    """Return the prompt-token count above which long-context rates apply.

    Reads ``long_context_threshold_tokens`` from the model profile. A
    missing/falsy value, or any failure while loading or converting it,
    yields LONG_CONTEXT_THRESHOLD_TOKENS_DEFAULT.
    """
    try:
        from recoil.core.model_profiles import get_profile

        configured = get_profile(model_id).get("long_context_threshold_tokens")
        return int(configured or LONG_CONTEXT_THRESHOLD_TOKENS_DEFAULT)
    except Exception:  # noqa: BLE001
        return LONG_CONTEXT_THRESHOLD_TOKENS_DEFAULT


# ---------------------------------------------------------------------------
# Auth resolution (§ 12a item 11)
# ---------------------------------------------------------------------------

def _resolve_api_key(env_var: str) -> Optional[str]:
    """Look up the API key in the environment.

    ``env_var`` is always consulted first. Only when it is unset/empty AND
    the caller asked for the default name (PRIMARY_AUTH_ENV_VAR) is
    FALLBACK_AUTH_ENV_VAR tried. A custom ``env_var`` never falls back.
    Key values are never logged.

    Returns:
        The key string, or None when the custom variable is unset/empty.
        For the default variable, whatever the fallback lookup returns.
    """
    value = os.environ.get(env_var)
    if value:
        return value
    if env_var != PRIMARY_AUTH_ENV_VAR:
        # Caller picked a specific variable: do not silently use another.
        return None
    return os.environ.get(FALLBACK_AUTH_ENV_VAR)


# ---------------------------------------------------------------------------
# Multimodal payload construction (§ 12a items 4, 12, 13)
# ---------------------------------------------------------------------------

def _build_part_for_artifact(
    artifact_path: Path,
    modality: str,
) -> dict:
    """Build a Gemini contents.parts entry for an image/video/audio artifact.

    Inline strategy: base64-encode the file (standard, NOT urlsafe — Gemini
    rejects '-' / '_' in inlineData.data). Hard-fail with EvalPayloadError
    if raw bytes exceed INLINE_DATA_RAW_FALLBACK_THRESHOLD — Files API path
    deferred to a later CP. The exception text explicitly mentions Files API
    so callers can wire the upload path themselves if they need it pre-CP-N.

    Args:
        artifact_path: Local file to embed.
        modality: "image" | "video" | "audio"; only used to pick a fallback
            mime type when one cannot be guessed from the file name.

    Returns:
        ``{"inlineData": {"mimeType": ..., "data": <base64 str>}}``.

    Raises:
        EvalPayloadError: the path does not exist, is not a regular file,
            cannot be stat'ed/read, or is too large for the inline path.
    """
    if not artifact_path.exists():
        raise EvalPayloadError(
            f"artifact_path does not exist: {artifact_path}"
        )
    # A directory (or other non-file) passes exists() but would blow up in
    # read_bytes() with an untyped OSError; reject it up front.
    if not artifact_path.is_file():
        raise EvalPayloadError(
            f"artifact_path is not a regular file: {artifact_path}"
        )

    try:
        raw_size = artifact_path.stat().st_size
    except OSError as e:
        raise EvalPayloadError(
            f"artifact_path could not be inspected: {artifact_path}: {e}"
        ) from e
    if raw_size > INLINE_DATA_RAW_FALLBACK_THRESHOLD:
        raise EvalPayloadError(
            f"artifact size {raw_size} exceeds inline-data fallback threshold "
            f"{INLINE_DATA_RAW_FALLBACK_THRESHOLD} (base64-encoded would "
            f"exceed the 20 MB Gemini ceiling); Files API upload path "
            f"deferred to CP-N+"
        )

    mime, _ = mimetypes.guess_type(str(artifact_path))
    if not mime:
        mime = _MIME_FALLBACK.get(modality, "application/octet-stream")

    # Keep filesystem failures (permissions, file vanished since the checks
    # above) inside the adapter's typed exception tree.
    try:
        file_bytes = artifact_path.read_bytes()
    except OSError as e:
        raise EvalPayloadError(
            f"artifact_path could not be read: {artifact_path}: {e}"
        ) from e
    # § 12a item 12 — STANDARD base64, NOT urlsafe.
    data_b64 = base64.standard_b64encode(file_bytes).decode("ascii")

    # Defensive double-check: encoded size should not exceed the absolute
    # 20 MB ceiling. The raw-bytes prelude almost always catches this, but
    # we belt-and-braces the encoded check (the file may also have grown
    # between stat() and read_bytes()).
    if len(data_b64) > INLINE_DATA_MAX_BYTES:
        raise EvalPayloadError(
            f"base64-encoded artifact size {len(data_b64)} exceeds inline-data "
            f"ceiling {INLINE_DATA_MAX_BYTES}; Files API upload path "
            f"deferred to CP-N+"
        )

    # § 12a item 2 — camelCase keys.
    return {"inlineData": {"mimeType": mime, "data": data_b64}}


def _build_part_for_file_uri(file_uri: str, mime_type: str) -> dict:
    """Return the ``fileData`` part for an artifact already in the Files API.

    § 12a item 13: pass the Files API ``name`` field (``"files/abc-123"``)
    as ``file_uri``, NOT the ``uri`` field (``"https://..."``). No upload
    happens here; the helper only shapes the request part so that Phase 3+
    or callers wiring their own upload path can reuse it.
    """
    file_data = {"mimeType": mime_type, "fileUri": file_uri}
    return {"fileData": file_data}


# ---------------------------------------------------------------------------
# Response parsing (§ 12a items 5, 6)
# ---------------------------------------------------------------------------

def _strip_code_fences(text: str) -> str:
    """Strip ```json ... ``` (or bare ```) wrappers if present.

    Two fence layouts are handled:

    * Multi-line: the entire opening line (fence plus any language tag) is
      dropped, then a trailing fence is removed.
    * Single-line (e.g. ```` ```json {"score": 1}``` ````): the opening
      fence and an immediately following alphabetic language tag are
      dropped, then a trailing fence is removed. Previously this layout
      collapsed to an empty string.

    Text that does not start with a fence is returned stripped of
    surrounding whitespace and otherwise untouched.
    """
    s = text.strip()
    if not s.startswith("```"):
        return s

    opening_line, newline, remainder = s.partition("\n")
    if newline:
        # Drop the opening fence line.
        s = remainder
    else:
        # Everything is on one line: peel off the fence, then the optional
        # language tag (letters only, so a JSON body starting with '{', '['
        # or a digit is never mistaken for a tag).
        s = opening_line[3:]
        tag_end = 0
        while tag_end < len(s) and s[tag_end].isalpha():
            tag_end += 1
        s = s[tag_end:]

    if s.endswith("```"):
        s = s[:-3]
    return s.strip()


def _parse_response(body_bytes: bytes) -> tuple[float, str, dict]:
    """Parse Gemini response body → (score, reasoning, raw_metadata).

    raw_metadata always includes a ``warnings`` list (possibly empty), plus
    ``usage`` (the response's usageMetadata dict, or ``{}``) and
    ``finish_reason``. Two-step JSON parse on
    ``candidates[0].content.parts[0].text`` (§ 12a item 5): the outer body is
    JSON, and the text inside it is the JSON verdict (optionally wrapped in
    code fences). MAX_TOKENS handling per § 12a item 6.

    Args:
        body_bytes: Raw HTTP response body.

    Returns:
        ``(score, reasoning, raw_metadata)`` where ``score`` is clipped into
        [0.0, 1.0] (the original is kept under
        ``raw_metadata["score_clipped"]`` when clipping happened) and
        ``reasoning`` is coerced to ``str``.

    Raises:
        EvalPayloadError: on unrecoverable parse failures — body or verdict
            not JSON / not a JSON object, candidates missing or malformed,
            text part missing, score or reasoning missing, score not
            numeric or NaN.
        EvalServerError: on SAFETY or any other finishReason besides STOP /
            MAX_TOKENS.
    """
    import math  # local import: only needed for the NaN guard below

    warnings: list[str] = []

    try:
        body = json.loads(body_bytes.decode("utf-8", errors="replace"))
    except Exception as e:
        raise EvalPayloadError(f"response body not JSON: {e}") from e

    # Valid JSON that is not an object (list, string, number, null) has no
    # .get(); report it as a payload problem instead of an AttributeError.
    if not isinstance(body, dict):
        raise EvalPayloadError(
            f"response body not a JSON object: {type(body).__name__}"
        )

    candidates = body.get("candidates")
    if not candidates or not isinstance(candidates, list):
        raise EvalPayloadError(f"response missing candidates: {body}")

    cand0 = candidates[0]
    if not isinstance(cand0, dict):
        raise EvalPayloadError(
            f"candidates[0] not a JSON object: {type(cand0).__name__}"
        )
    finish_reason = cand0.get("finishReason")

    if finish_reason == "SAFETY":
        raise EvalServerError(f"Gemini SAFETY finishReason: {body}")
    if finish_reason and finish_reason not in ("STOP", "MAX_TOKENS"):
        # OTHER / RECITATION / etc.
        raise EvalServerError(
            f"Gemini unexpected finishReason {finish_reason!r}: {body}"
        )

    try:
        text = cand0["content"]["parts"][0]["text"]
    except Exception as e:
        raise EvalPayloadError(
            f"response missing candidates[0].content.parts[0].text: {body}"
        ) from e

    if not isinstance(text, str):
        raise EvalPayloadError(
            f"parts[0].text not a string: got {type(text).__name__}"
        )

    stripped = _strip_code_fences(text)

    try:
        verdict = json.loads(stripped)
    except Exception as e:
        # § 12a item 6 — MAX_TOKENS + unparseable inner JSON => raise with
        # the load-bearing reason marker; otherwise also raise.
        if finish_reason == "MAX_TOKENS":
            raise EvalPayloadError(
                f"{TRUNCATED_UNPARSEABLE_MARKER}: MAX_TOKENS verdict not JSON: "
                f"{stripped[:256]}: {e}"
            ) from e
        raise EvalPayloadError(
            f"verdict text not JSON-parseable: {stripped[:256]}: {e}"
        ) from e

    # The verdict parsed even though the model hit its token limit: usable,
    # but flag it so callers know the text may have been cut short.
    if finish_reason == "MAX_TOKENS":
        warnings.append(WARNING_TRUNCATED_MAX_TOKENS)

    if not isinstance(verdict, dict):
        raise EvalPayloadError(
            f"verdict not a JSON object: {type(verdict).__name__}"
        )

    score = verdict.get("score")
    reasoning = verdict.get("reasoning")
    if score is None:
        raise EvalPayloadError(f"verdict missing score: {verdict}")
    if reasoning is None:
        raise EvalPayloadError(f"verdict missing reasoning: {verdict}")

    try:
        score = float(score)
    except Exception as e:
        raise EvalPayloadError(
            f"verdict score not coercible to float: {score!r}: {e}"
        ) from e

    # json.loads accepts the bare literal NaN, and float("nan") parses too.
    # NaN compares False against both clip bounds below, so without this
    # guard it would be returned as a "valid" score outside [0.0, 1.0].
    if math.isnan(score):
        raise EvalPayloadError(f"verdict score is NaN: {verdict}")

    # Callers do usage.get(...); anything but a dict is treated as absent.
    usage = body.get("usageMetadata")
    if not isinstance(usage, dict):
        usage = {}

    raw_md: dict[str, Any] = {
        "usage": usage,
        "finish_reason": finish_reason,
        "warnings": warnings,
    }

    # Score clipping with a warning surfaced in raw_metadata. `warnings` is
    # the same list object stored in raw_md, so the append is visible there.
    if score < 0.0 or score > 1.0:
        raw_md["score_clipped"] = {"original": score}
        warnings.append(WARNING_SCORE_CLIPPED)
        score = max(0.0, min(1.0, score))

    return score, str(reasoning), raw_md


# ---------------------------------------------------------------------------
# Backoff helpers (§ 12a items 7, 8, 9)
# ---------------------------------------------------------------------------

def _jitter(base: float) -> float:
    """Return ``base`` perturbed by multiplicative ±20% uniform jitter.

    The result is drawn uniformly from [base*0.8, base*1.2].
    """
    lower = base * 0.8
    upper = base * 1.2
    return random.uniform(lower, upper)


def _retry_after_seconds(headers: Optional[dict]) -> Optional[float]:
    """Extract the Retry-After header value in seconds, capped.

    Honors only the numeric-seconds form — HTTP-date form is rare for
    generateContent and not parsed here. The header name is matched
    case-insensitively (HTTP field names are case-insensitive), so any
    spelling such as ``Retry-after`` is honored.

    Args:
        headers: Response headers as a mapping, or None.

    Returns:
        The delay in seconds, at most RETRY_AFTER_CAP_S, or None when the
        header is absent, empty, non-numeric, negative, or NaN.
    """
    if not headers:
        return None
    # First non-empty header whose name is "retry-after" in any casing.
    raw = next(
        (
            value
            for name, value in headers.items()
            if value and str(name).lower() == "retry-after"
        ),
        None,
    )
    if not raw:
        return None
    try:
        v = float(raw)
    except (TypeError, ValueError):
        return None
    # `not v >= 0` rejects negatives AND NaN (every comparison with NaN is
    # False); a NaN delay would otherwise reach time.sleep() and raise.
    if not v >= 0:
        return None
    return min(v, RETRY_AFTER_CAP_S)


def _schedule_for(code: Optional[int]) -> list[float]:
    """Return the pre-jitter backoff schedule for a retry signal.

    503 and 504 have dedicated schedules; every other status code, and
    ``None`` (network failure), uses RETRY_SCHEDULE_DEFAULT.
    """
    if code == 504:
        return RETRY_SCHEDULE_504
    return RETRY_SCHEDULE_503 if code == 503 else RETRY_SCHEDULE_DEFAULT


def _max_retries_for(code: Optional[int]) -> int:
    """Return how many retries a signal is allowed.

    ``None`` stands for a network-level failure and gets
    NETWORK_RETRY_COUNT. HTTP status codes are looked up in RETRY_COUNTS;
    a code with no entry gets 0, i.e. it is not retried.
    """
    return NETWORK_RETRY_COUNT if code is None else RETRY_COUNTS.get(code, 0)


# ---------------------------------------------------------------------------
# score_artifact — public entrypoint
# ---------------------------------------------------------------------------

def score_artifact(
    *,
    artifact_path: Path,
    artifact_modality: ArtifactModality,
    prompt: str,
    judge_id: str,
    model_id: str = DEFAULT_MODEL_ID,
    api_key_env_var: str = DEFAULT_AUTH_ENV_VAR,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    transport: Optional[Callable] = None,
) -> EvalProviderResult:
    """Score one artifact via Gemini Vision (gemini-3.1-pro-preview).

    Validates inputs (modality, then API key, then artifact file), builds
    one generateContent request with the prompt text followed by the
    inline artifact, and sends it through ``transport`` inside a retry
    loop. The request pins temperature 0.0, maxOutputTokens 1024 and a JSON
    response mime type.

    Args:
        artifact_path: Local path to the artifact file. Must exist.
        artifact_modality: "image" | "video" | "audio".
        prompt: The rubric / scoring prompt the model should follow. The
            adapter does NOT mangle this — caller is responsible for
            instructing the model to return ``{"score": float, "reasoning": str}``.
        judge_id: Opaque label threaded through into raw_metadata so panel
            callers can correlate which judge produced this verdict.
        model_id: Defaults to DEFAULT_MODEL_ID. Override only for testing
            or future model rotations.
        api_key_env_var: env var to read the API key from. Defaults to
            GEMINI_API_KEY with GOOGLE_API_KEY as fallback (only when the
            default env var name is used).
        timeout_s: per-call timeout.
        transport: callable (url, *, headers, body, timeout) -> response.
            Default uses urllib.request. Tests pass a mock. The response is
            used as a context manager and must provide ``read()``;
            ``status`` and ``headers`` are optional attributes.

    Returns:
        EvalProviderResult with score in [0.0, 1.0], reasoning text, cost
        estimate, model id echoed back, optional request_id, and raw
        metadata (incl. warnings list).

    Raises:
        EvalAuthError on 401/403 / missing key.
        EvalQuotaError on 402 (reserved).
        EvalPayloadError on 400/404 / unsupported modality / artifact
            missing / artifact too large / response unparseable / verdict
            missing score|reasoning.
        EvalRateLimitError on 429 after retry exhaustion.
        EvalServerError on 5xx after retry exhaustion / SAFETY finishReason.
        EvalNetworkError on URLError / socket.timeout after retry exhaustion.
        EvalProviderError (base) on any other non-200 status.
    """
    if artifact_modality not in SUPPORTED_MODALITIES:
        raise EvalPayloadError(
            f"unsupported artifact_modality {artifact_modality!r}; "
            f"expected one of {'/'.join(SUPPORTED_MODALITIES)}"
        )

    api_key = _resolve_api_key(api_key_env_var)
    if not api_key:
        raise EvalAuthError(f"env var {api_key_env_var} not set")

    # Reads and base64-encodes the file; raises EvalPayloadError before any
    # network traffic if the artifact is missing or too large.
    artifact_part = _build_part_for_artifact(
        Path(artifact_path), artifact_modality
    )

    # § 12a item 2 — camelCase request body.
    body_obj = {
        "contents": [
            {
                "role": "user",
                "parts": [
                    {"text": prompt},
                    artifact_part,
                ],
            }
        ],
        "generationConfig": {
            "temperature": 0.0,
            "maxOutputTokens": 1024,
            "responseMimeType": "application/json",
        },
    }
    # Serialized once; the same bytes are re-sent on every retry.
    body_bytes = json.dumps(body_obj).encode("utf-8")

    # The key travels in a header, not in the URL.
    headers = {
        "x-goog-api-key": api_key,
        "Content-Type": "application/json",
    }
    url = ENDPOINT_TEMPLATE.format(model=model_id)
    tport = transport or _default_transport

    # ---- Retry loop -------------------------------------------------------
    # Pattern: try; on retryable signal sleep & continue; on terminal
    # signal raise. Each retryable code carries its own retry budget +
    # base schedule.
    # Keys are HTTP status codes (int) or the string "network"; budgets are
    # tracked independently per key across the whole call.
    attempts: dict[Any, int] = {}  # signal -> attempts so far
    while True:
        try:
            resp = tport(
                url, headers=headers, body=body_bytes, timeout=timeout_s
            )
            with resp:
                # A response object without a `status` attribute is treated
                # as a 200.
                code = getattr(resp, "status", 200)
                resp_headers = dict(getattr(resp, "headers", {}) or {})
                resp_body = resp.read()
            if code != 200:
                # Non-200 status WITHOUT urllib raising — synthesize the
                # error and feed it through the retry handler.
                err = _classify_http_error(code, resp_body or b"")
                handled = _handle_retryable(
                    err, code=code, headers=resp_headers,
                    attempts=attempts, origin=None,
                )
                if handled is None:
                    # EvalProviderError is not matched by either except
                    # clause below, so this propagates to the caller.
                    raise err
                time.sleep(handled)
                continue

            # Parse failures raise EvalPayloadError / EvalServerError, which
            # likewise escape the except clauses below — they are never
            # retried.
            score, reasoning, raw_md = _parse_response(resp_body)
            usage = raw_md.get("usage", {}) or {}
            input_tokens = int(usage.get("promptTokenCount") or 0)

            # Output tokens with totalTokenCount fallback.
            cand_tokens = usage.get("candidatesTokenCount")
            if cand_tokens is None:
                total_tokens = usage.get("totalTokenCount")
                if total_tokens is not None:
                    # Derive output as total minus prompt, floored at 0.
                    output_tokens = max(0, int(total_tokens) - input_tokens)
                else:
                    output_tokens = 0
                    # raw_md["warnings"] is guaranteed populated by _parse_response
                    raw_md["warnings"].append(WARNING_MISSING_TOKEN_COUNT)
            else:
                output_tokens = int(cand_tokens or 0)

            # Strictly greater than the threshold selects long-context rates.
            long_ctx = input_tokens > _long_context_threshold(model_id)
            cost = _compute_cost(
                model_id,
                input_tokens,
                output_tokens,
                long_context=long_ctx,
            )

            # Surface judge_id + modality on raw_metadata for panel
            # correlation downstream.
            raw_md["judge_id"] = judge_id
            raw_md["artifact_modality"] = artifact_modality

            # NOTE(review): dict lookups are case-sensitive, so only these
            # three exact spellings are recognized; any other casing yields
            # request_id=None.
            request_id = (
                resp_headers.get("x-request-id")
                or resp_headers.get("X-Request-Id")
                or resp_headers.get("request-id")
            )

            return EvalProviderResult(
                score=score,
                reasoning=reasoning,
                cost_usd=cost,
                model_used=model_id,
                request_id=request_id,
                raw_metadata=raw_md,
            )

        # HTTPError must be handled before URLError: it is a URLError
        # subclass and would otherwise be treated as a network failure.
        except urllib.error.HTTPError as e:
            err_body = b""
            try:
                err_body = e.read() if hasattr(e, "read") else b""
            except Exception:  # noqa: BLE001
                # Body is only used for the error message; losing it is fine.
                err_body = b""
            err = _classify_http_error(e.code, err_body)
            err_headers = dict(getattr(e, "headers", {}) or {})
            handled = _handle_retryable(
                err, code=e.code, headers=err_headers, attempts=attempts,
                origin=e,
            )
            if handled is None:
                # Non-retryable: re-raise.
                raise err from e
            # `handled` is a sleep duration; loop continues.
            time.sleep(handled)
            continue

        except (urllib.error.URLError, socket.timeout, TimeoutError,
                ConnectionError) as e:
            # Network / timeout — retried NETWORK_RETRY_COUNT times.
            # Retry-After does not apply here; only the jittered default
            # schedule is used.
            attempts.setdefault("network", 0)
            attempts["network"] += 1
            if attempts["network"] > NETWORK_RETRY_COUNT:
                raise EvalNetworkError(f"network: {e}") from e
            base_idx = attempts["network"] - 1
            schedule = _schedule_for(None)
            # Clamp to the last entry in case the schedule is shorter than
            # the retry budget.
            base = schedule[min(base_idx, len(schedule) - 1)]
            time.sleep(_jitter(base))
            continue


def _handle_retryable(
    err: EvalProviderError,
    *,
    code: Optional[int],
    headers: Optional[dict],
    attempts: dict,
    origin: Optional[BaseException],
) -> Optional[float]:
    """Decide whether ``err`` deserves another attempt.

    Args:
        err: The already-classified adapter exception.
        code: HTTP status code that produced ``err``; selects both the retry
            budget and the backoff schedule.
        headers: Response headers, consulted for Retry-After.
        attempts: Per-signal attempt counters. MUTATED: the counter for
            ``code`` is incremented whenever the code has a retry budget,
            including the call that finds the budget exhausted.
        origin: The underlying exception, if any. Not referenced in the body.

    Returns:
        Seconds to sleep before retrying — the capped Retry-After value when
        the header is usable, otherwise the jittered schedule entry — or
        None when the error must not be retried (caller raises).
    """
    fail_fast = (EvalAuthError, EvalQuotaError, EvalPayloadError)
    if isinstance(err, fail_fast):
        return None

    # Codes without a positive budget are never retried.
    allowed = _max_retries_for(code)
    if allowed <= 0:
        return None

    used = attempts.setdefault(code, 0) + 1
    attempts[code] = used
    if used > allowed:
        return None

    # Schedule slot for this retry, clamped to the schedule's last entry.
    schedule = _schedule_for(code)
    slot = min(used - 1, len(schedule) - 1)
    base_delay = schedule[slot]

    # A usable Retry-After (already capped) wins over the schedule and is
    # returned without jitter.
    server_hint = _retry_after_seconds(headers)
    if server_hint is None:
        return _jitter(base_delay)
    return server_hint


# Public surface of the module: the entry point, the result type, the
# exception tree, and the constants callers may need to reference.
# Underscore-prefixed helpers are not exported.
__all__ = [
    "score_artifact",
    "EvalProviderResult",
    "EvalProviderError",
    "EvalAuthError",
    "EvalQuotaError",
    "EvalPayloadError",
    "EvalRateLimitError",
    "EvalServerError",
    "EvalNetworkError",
    "ArtifactModality",
    "SUPPORTED_MODALITIES",
    "DEFAULT_MODEL_ID",
    "DEFAULT_AUTH_ENV_VAR",
    "PRIMARY_AUTH_ENV_VAR",
    "FALLBACK_AUTH_ENV_VAR",
    "INLINE_DATA_MAX_BYTES",
    "INLINE_DATA_RAW_FALLBACK_THRESHOLD",
    "RETRY_AFTER_CAP_S",
    "ENDPOINT_TEMPLATE",
    "WARNING_SCORE_CLIPPED",
    "WARNING_TRUNCATED_MAX_TOKENS",
    "WARNING_MISSING_TOKEN_COUNT",
    "TRUNCATED_UNPARSEABLE_MARKER",
]
