"""Provider-call observability.

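Every adapter call (submit + terminal poll) writes one row to
recoil/execution/observability.sqlite. The DB file and schema are created
lazily on first use: every helper that touches the DB (writes, trims and
drift queries) calls ensure_schema() first. Schema is immutable within a
major version — do NOT alter columns without a migration plan.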

Retention: 90 days. trim_old() is called by the nightly drift report
(recoil/tools/provider_drift_report.py) before producing its summary.
"""

from __future__ import annotations

import logging
import os
import sqlite3
import threading
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator, Optional

# Process-local lock serialising this module's writers (schema creation,
# inserts, trims). It does not coordinate across processes; concurrent
# processes rely on the sqlite3 connect timeout instead.
_WRITE_LOCK = threading.Lock()

# The SQLite file sits two directory levels above this module's resolved path.
_DB_PATH = Path(__file__).resolve().parent.parent / "observability.sqlite"

logger = logging.getLogger(__name__)


# DDL run by ensure_schema() via Connection.executescript(). Every statement
# uses IF NOT EXISTS, so re-running it against an existing DB is a no-op.
# Column notes, based on how this module reads/writes them:
#   ts      - UTC timestamp string in "%Y-%m-%dT%H:%M:%SZ" form, set by
#             record_call(). It is fixed-width ISO-8601, so the string
#             comparisons in trim_old() / query_drift() order chronologically.
#   status  - query_drift() only aggregates rows whose status is 'COMPLETED'.
#   listed_cost / observed_cost - compared by query_drift(). Units are not
#             visible here (presumably the same currency) — TODO confirm.
# The indexes back the ts range filters, the (provider, model) grouping and
# the status filter used by trim_old() / query_drift().
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS provider_calls (
    id           INTEGER PRIMARY KEY AUTOINCREMENT,
    ts           TEXT    NOT NULL,
    provider     TEXT    NOT NULL,
    model        TEXT    NOT NULL,
    tier         TEXT    NOT NULL,
    duration_s   REAL,
    listed_cost  REAL,
    observed_cost REAL,
    latency_ms   INTEGER,
    status       TEXT    NOT NULL,
    task_id      TEXT,
    shot_id      TEXT,
    error        TEXT
);

CREATE INDEX IF NOT EXISTS idx_provider_calls_ts ON provider_calls (ts);
CREATE INDEX IF NOT EXISTS idx_provider_calls_provider_model ON provider_calls (provider, model);
CREATE INDEX IF NOT EXISTS idx_provider_calls_status ON provider_calls (status);
"""


def db_path() -> Path:
    """Return where the observability SQLite file lives.

    The file itself may not exist yet. It is only created once
    ensure_schema() has run.
    """
    location: Path = _DB_PATH
    return location


@contextmanager
def _connect() -> Iterator[sqlite3.Connection]:
    """Yield a short-lived connection to the observability DB.

    If the ``with`` body finishes normally, the pending transaction is
    committed. If it raises, the commit is skipped and the exception
    propagates. Closing without a commit discards the uncommitted work.
    The connection is closed on every exit path. ``timeout=10.0`` makes a
    locked database wait up to 10 s before sqlite3 raises.
    """
    db_file = str(_DB_PATH)
    handle = sqlite3.connect(db_file, timeout=10.0)
    try:
        yield handle
        # Only reached when the caller's with-body did not raise.
        handle.commit()
    finally:
        handle.close()


def ensure_schema() -> None:
    """Make sure the observability DB file and its schema exist.

    Creates the DB's parent directory when missing. Then, while holding
    the module write lock, it runs SCHEMA_SQL, whose IF NOT EXISTS clauses
    make repeated calls harmless.
    """
    db_dir = _DB_PATH.parent
    db_dir.mkdir(exist_ok=True, parents=True)
    with _WRITE_LOCK:
        with _connect() as conn:
            conn.executescript(SCHEMA_SQL)


def record_call(
    *,
    provider: str,
    model: str,
    tier: str,
    status: str,
    duration_s: Optional[float] = None,
    listed_cost: Optional[float] = None,
    observed_cost: Optional[float] = None,
    latency_ms: Optional[int] = None,
    task_id: Optional[str] = None,
    shot_id: Optional[str] = None,
    error: Optional[str] = None,
) -> None:
    """Append one row to ``provider_calls``.

    This is best-effort by design, because observability must never crash
    a production generation. ANY exception raised while preparing the
    schema or inserting (not only DB errors) is logged at WARNING level
    and then discarded; nothing is raised to the caller. ``ts`` is
    stamped here as UTC ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    try:
        ensure_schema()
        stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        # Positional order must match the column list in the INSERT below.
        values = (
            stamp, provider, model, tier, duration_s, listed_cost,
            observed_cost, latency_ms, status, task_id, shot_id, error,
        )
        with _WRITE_LOCK, _connect() as conn:
            conn.execute(
                """
                INSERT INTO provider_calls
                    (ts, provider, model, tier, duration_s, listed_cost,
                     observed_cost, latency_ms, status, task_id, shot_id, error)
                VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
                """,
                values,
            )
    except Exception as exc:
        logger.warning("observability.record_call failed: %s", exc)


def trim_old(days: int = 90) -> int:
    """Delete ``provider_calls`` rows older than ``days`` days.

    The cutoff is ``now - days * 86400`` seconds. It is rendered in the
    same UTC ``%Y-%m-%dT%H:%M:%SZ`` format that record_call() writes into
    ``ts``, so the string comparison in the WHERE clause orders rows
    chronologically.

    Args:
        days: Retention window in days. Must be non-negative. A negative
            value would put the cutoff in the future and delete every row.

    Returns:
        Number of rows deleted. Returns 0 if the DB work failed; that
        failure is logged at WARNING level, not raised.

    Raises:
        ValueError: if ``days`` is negative.
    """
    if days < 0:
        # Refuse rather than silently wiping the whole table.
        raise ValueError(f"trim_old: days must be >= 0, got {days!r}")
    cutoff_ts = time.strftime(
        "%Y-%m-%dT%H:%M:%SZ", time.gmtime(time.time() - days * 86400)
    )
    try:
        ensure_schema()
        with _WRITE_LOCK, _connect() as conn:
            cur = conn.execute(
                "DELETE FROM provider_calls WHERE ts < ?", (cutoff_ts,)
            )
            # Returning here still exits the with-block normally, so
            # _connect() commits the DELETE before closing.
            return cur.rowcount or 0
    except Exception as e:
        logger.warning("observability.trim_old failed: %s", e)
        return 0


def query_drift(
    *,
    since_days: int = 30,
    min_samples: int = 5,
) -> list[dict]:
    """Return per (provider, model, tier) aggregated drift rows.

    Only ``COMPLETED`` calls whose ``ts`` lies within the last
    ``since_days`` days are considered.

    Drift = observed_cost / listed_cost, computed as
    ``avg_observed / avg_listed``. Both averages are taken over the SAME
    rows: those that carry a listed AND an observed cost. If listed_cost
    were averaged over every row while observed_cost only covered rows
    that reported it, the two averages would describe different calls.
    The ratio would then be skewed whenever some calls lack an observed
    cost.

    Groups with < min_samples observed rows are skipped (noise floor).
    Groups whose paired averages are missing or zero are also skipped.

    Returns:
        One dict per surviving group with keys ``provider``, ``model``,
        ``tier``, ``n`` (completed rows in the window), ``n_observed``
        (rows with an observed_cost), ``avg_listed`` / ``avg_observed``
        (paired averages, see above) and ``drift_ratio``. If any
        exception occurs, it is logged at WARNING level and ``[]`` is
        returned.
    """
    cutoff_ts = time.strftime(
        "%Y-%m-%dT%H:%M:%SZ", time.gmtime(time.time() - since_days * 86400)
    )
    try:
        ensure_schema()
        with _connect() as conn:
            # AVG() ignores NULLs, and each CASE yields NULL unless the
            # *other* cost is present. So avg_listed and avg_observed both
            # cover exactly the rows that have both costs.
            cur = conn.execute(
                """
                SELECT provider, model, tier,
                       COUNT(*) AS n,
                       SUM(CASE WHEN observed_cost IS NOT NULL THEN 1 ELSE 0 END) AS n_observed,
                       AVG(CASE WHEN observed_cost IS NOT NULL THEN listed_cost END) AS avg_listed,
                       AVG(CASE WHEN listed_cost IS NOT NULL THEN observed_cost END) AS avg_observed
                FROM provider_calls
                WHERE ts >= ?
                  AND status = 'COMPLETED'
                GROUP BY provider, model, tier
                """,
                (cutoff_ts,),
            )
            rows: list[dict] = []
            for provider, model, tier, n, n_obs, avg_listed, avg_observed in cur.fetchall():
                if (n_obs or 0) < min_samples:
                    continue
                # No paired data, or a zero denominator: no meaningful ratio.
                if not avg_listed or not avg_observed:
                    continue
                ratio = avg_observed / avg_listed
                rows.append({
                    "provider": provider,
                    "model": model,
                    "tier": tier,
                    "n": n,
                    "n_observed": n_obs,
                    "avg_listed": avg_listed,
                    "avg_observed": avg_observed,
                    "drift_ratio": ratio,
                })
            return rows
    except Exception as e:
        logger.warning("observability.query_drift failed: %s", e)
        return []


# Names exported by ``from <module> import *``. SCHEMA_SQL, logger and the
# underscore-prefixed helpers/globals are not exported.
__all__ = ["db_path", "ensure_schema", "record_call", "trim_old", "query_drift"]
