#!/usr/bin/env python3
"""Provider cost drift report.

Reads recoil/execution/observability.sqlite and emits a table of
(provider, model, tier) with observed/listed cost ratio. Flags any
row where |ratio - 1| >= 15% as DRIFT.

Runs nightly via cron in production. Also trims rows older than 90 days.

Usage:
  python3 recoil/tools/provider_drift_report.py
  python3 recoil/tools/provider_drift_report.py --since-days 7 --threshold 0.20
  python3 recoil/tools/provider_drift_report.py --json
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# Allow running without installing recoil as a package.
#
# The second ancestor of this file's resolved path is prepended to sys.path
# (unless already present) so that the `recoil.*` import below can resolve
# when the script is launched directly. The insert must happen before the
# import, so keep these statements in this order.
#
# NOTE(review): if this file lives at <root>/recoil/tools/ (as the usage
# examples in the module docstring suggest), parents[1] is the `recoil`
# directory itself, whereas `import recoil` needs the directory that
# *contains* it (parents[2]) on sys.path — confirm against the repo layout.
_RECOIL = Path(__file__).resolve().parents[1]
if str(_RECOIL) not in sys.path:
    sys.path.insert(0, str(_RECOIL))

# Project-local module; this script calls its trim_old() and query_drift().
from recoil.execution.providers import observability as obs


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the drift report."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The default formatter re-wraps the description, which collapses the
        # "Usage:" examples of the module docstring into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--since-days", type=int, default=30)
    parser.add_argument("--min-samples", type=int, default=5)
    parser.add_argument("--threshold", type=float, default=0.15,
                        help="Flag drift >= this fraction (default: 0.15 = 15%%)")
    parser.add_argument("--trim-days", type=int, default=90,
                        help="Delete rows older than this many days (default: 90)")
    parser.add_argument("--no-trim", action="store_true")
    parser.add_argument("--json", action="store_true")
    return parser


def _validate_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject negative (or NaN) numeric options with a usage error (exit status 2)."""
    numeric_options = (
        ("--since-days", args.since_days),
        ("--min-samples", args.min_samples),
        ("--threshold", args.threshold),
        ("--trim-days", args.trim_days),
    )
    for flag, value in numeric_options:
        # `not value >= 0` (rather than `value < 0`) also rejects a NaN
        # threshold, which would otherwise silently never flag anything.
        if not value >= 0:
            parser.error(f"{flag} must be non-negative (got {value})")


def _annotate_rows(rows: list[dict], threshold: float) -> list[dict]:
    """Return copies of *rows* with ``drift_pct`` and ``flagged`` added.

    ``drift_pct`` is ``|drift_ratio - 1|`` as a percentage rounded to one
    decimal; ``flagged`` is True when the unrounded drift is >= *threshold*.
    The input rows are not modified.
    """
    annotated = []
    for row in rows:
        # NOTE(review): assumes every row carries a numeric "drift_ratio" —
        # confirm query_drift never yields None here (e.g. zero listed cost).
        drift = abs(row["drift_ratio"] - 1.0)
        annotated.append(
            {**row, "drift_pct": round(drift * 100, 1), "flagged": drift >= threshold}
        )
    return annotated


def _print_table(rows: list[dict], threshold: float) -> None:
    """Print the human-readable report to stdout, largest drift first."""
    print(f"Provider drift report — {len(rows)} groups, threshold={threshold*100:.0f}%")
    print("-" * 96)
    print(f"{'provider':<10}{'model':<20}{'tier':<20}{'n':>6}{'listed':>12}{'observed':>12}{'drift%':>10}")
    # Sort on the unrounded drift; reverse=True keeps ties in input order.
    by_drift = sorted(rows, key=lambda x: abs(x["drift_ratio"] - 1.0), reverse=True)
    for r in by_drift:
        marker = "  DRIFT" if r["flagged"] else ""
        print(
            f"{r['provider']:<10}{r['model']:<20}{r['tier']:<20}"
            f"{r['n']:>6}{r['avg_listed']:>12.4f}{r['avg_observed']:>12.4f}"
            f"{r['drift_pct']:>10.1f}{marker}"
        )


def main() -> int:
    """Run the drift report and return the process exit status.

    Parses command-line options, optionally trims old rows via
    ``obs.trim_old``, fetches per-group drift via ``obs.query_drift``, and
    writes either a JSON document (``--json``) or a text table to stdout.

    Returns:
        1 if at least one group's drift is >= ``--threshold``, otherwise 0.

    Invalid options terminate via ``parser.error`` (exit status 2).

    NOTE: an uncaught exception also ends the interpreter with status 1, so
    callers cannot distinguish "drift found" from a crash by exit code alone.
    """
    parser = _build_parser()
    args = parser.parse_args()
    _validate_args(parser, args)

    if not args.no_trim:
        # Trimming runs before the query, so a retention window shorter than
        # the report window removes rows the report was asked to cover.
        if args.trim_days < args.since_days:
            print(
                f"warning: --trim-days ({args.trim_days}) is shorter than "
                f"--since-days ({args.since_days}); rows are trimmed before "
                "the report is computed",
                file=sys.stderr,
            )
        trimmed = obs.trim_old(days=args.trim_days)
        if trimmed and not args.json:
            print(f"(trimmed {trimmed} rows older than {args.trim_days}d)", file=sys.stderr)

    rows = _annotate_rows(
        obs.query_drift(since_days=args.since_days, min_samples=args.min_samples),
        args.threshold,
    )
    flagged = [r for r in rows if r["flagged"]]
    exit_code = 1 if flagged else 0

    if args.json:
        json.dump({"rows": rows, "flagged": flagged, "threshold": args.threshold}, sys.stdout, indent=2)
        sys.stdout.write("\n")
        return exit_code

    if not rows:
        print("No observability rows in window.")
        return 0

    _print_table(rows, args.threshold)
    return exit_code


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
