"""
learning_engine.py — Pattern detection from gate verdicts and human overrides.

Learns from:
1. Gate verdicts — which models/prompts/shot types produce passing results
2. Human overrides — when JT approves a gate-failed shot (false negative) or
   rejects a gate-passed shot (false positive)
3. Retry patterns — which failure types are transient vs persistent
4. Happy accidents — unscripted takes that JT selected

Produces:
1. Pattern reports (JSON)
2. Gate calibration suggestions
3. pipeline-learnings-auto.md updates (appended to the shared Dropbox file)

Storage: projects/{project}/state/learning/ as JSON
Cross-project: ~/Dropbox/Claude_Config/memory/pipeline-learnings-auto.md
"""

import json
import logging
import os
import tempfile
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from recoil.pipeline.core.cost import read_cost_from_record_safe

logger = logging.getLogger(__name__)

# Shared insights path (synced via Dropbox across machines)
_SHARED_LEARNINGS_PATH = Path.home() / "Dropbox" / "Claude_Config" / "memory" / "pipeline-learnings-auto.md"

# Minimum samples before LearningEngine overrides static chain ordering
AUTONOMY_THRESHOLD = 8


@dataclass
class StrategyStats:
    """Aggregated retry outcomes for one strategy.

    Instances are built by ``LearningEngine.query_strategy_stats`` from the
    retries.jsonl records that share the same ``strategy_applied`` value.
    """

    # Value of the records' "strategy_applied" field.
    strategy_name: str
    # Fraction of the records whose "succeeded" flag is truthy.
    success_rate: float
    # Number of records that were aggregated.
    n_samples: int
    # Mean per-record cost, as returned by read_cost_from_record_safe.
    avg_cost_usd: float
    # Mean of the records' "latency_s" values.
    avg_latency_s: float
    # Largest "timestamp" value among the records.
    last_used_ts: float


class LearningEngine:
    """Ingests generation results and produces insights.

    Not thread-safe — designed for single-threaded production loop.
    """

    def __init__(self, project: str, state_dir: Optional[Path] = None):
        """Bind the engine to a project and make sure its state directory exists.

        Args:
            project: Project name; stamped on every ingested record.
            state_dir: Learning data directory. A ``str`` (or any other
                path-like value) is accepted and converted to ``Path``.
                When omitted or falsy, defaults to the ``learning_dir``
                resolved by ``ProjectPaths.for_project(project)``
                (projects/{project}/state/learning/).

        The directory is created, including missing parents.
        """
        self.project = project

        if state_dir:
            # Coerce so a plain string works too: the rest of the class uses
            # Path methods and the ``/`` operator on this attribute.
            self._state_dir = Path(state_dir)
        else:
            # Imported only on this branch, so passing an explicit state_dir
            # never imports recoil.core.paths.
            from recoil.core.paths import ProjectPaths
            self._state_dir = ProjectPaths.for_project(project).learning_dir
        self._state_dir.mkdir(parents=True, exist_ok=True)

        # In-memory accumulators; written to disk by flush().
        self._verdicts: list[dict] = []
        self._overrides: list[dict] = []
        self._retries: list[dict] = []

    def ingest_result(
        self,
        shot_id: str,
        step_result,
        shot_data: Optional[dict] = None,
        model: str = "",
        phase: str = "",
        shot_type: str = "",
    ) -> None:
        """Buffer the outcome of one generation step for later analysis.

        Meant to be called after every StepRunner execution, whether it
        succeeded or failed. Once the in-memory buffer holds 50 verdicts,
        ``flush()`` is invoked.

        Args:
            shot_id: Identifier of the shot that was generated.
            step_result: Result object. Its ``gate_verdict``, ``success``,
                ``cost_usd`` and ``final_state`` attributes are read, plus
                ``model`` / ``pipeline`` when the matching argument is empty.
            shot_data: Currently unused by this method.
            model: Stored instead of ``step_result.model`` when non-empty.
            phase: Stored instead of ``step_result.pipeline`` when non-empty.
            shot_type: Shot type label stored with the record.
        """
        verdict = step_result.gate_verdict
        if verdict:
            gate_name = verdict.gate_name
            gate_passed = verdict.passed
            gate_reason = verdict.reason
            gate_deferred = verdict.deferred
        else:
            # No gate ran (or the verdict is falsy): record explicit nulls.
            gate_name = gate_passed = gate_reason = gate_deferred = None

        record = {
            "timestamp": time.time(),
            "project": self.project,
            "shot_id": shot_id,
            "model": model if model else step_result.model,
            "phase": phase if phase else step_result.pipeline,
            "shot_type": shot_type,
            "success": step_result.success,
            "cost": step_result.cost_usd,
            "gate_name": gate_name,
            "gate_passed": gate_passed,
            "gate_reason": gate_reason,
            "gate_deferred": gate_deferred,
            "final_state": step_result.final_state,
        }

        # Keep only the gate detail fields that are stored for pattern
        # analysis; everything else in the details is dropped.
        if verdict and verdict.details:
            kept_keys = ("total_score", "mismatches", "failure_category",
                         "drift_count", "skipped")
            record["gate_details"] = {
                name: value
                for name, value in verdict.details.items()
                if name in kept_keys
            }

        self._verdicts.append(record)

        # Auto-flush once 50 verdicts are buffered.
        if len(self._verdicts) >= 50:
            self.flush()

    def ingest_override(
        self,
        shot_id: str,
        override_type: str,
        gate_name: str,
        gate_passed: bool,
        human_decision: str,
        notes: str = "",
    ) -> None:
        """Buffer a human decision that contradicted a gate verdict.

        The record stays in memory until ``flush()`` writes it out.

        Args:
            shot_id: Identifier of the shot whose verdict was overridden.
            override_type: "false_negative" (approved despite gate fail) or
                "false_positive" (rejected despite gate pass).
            gate_name: Which gate was overridden.
            gate_passed: Whether the gate originally passed.
            human_decision: "approved" or "rejected".
            notes: Free-text explanation from the reviewer.
        """
        record = dict(
            timestamp=time.time(),
            project=self.project,
            shot_id=shot_id,
            override_type=override_type,
            gate_name=gate_name,
            gate_passed=gate_passed,
            human_decision=human_decision,
            notes=notes,
        )
        self._overrides.append(record)

    def ingest_retry(
        self,
        shot_id: str = "",
        pass_id: str = "",
        failure_category: str = "",
        failure_mode: str = "",
        strategy_applied: str = "",
        prior_strategies: Optional[list[str]] = None,
        retry_number: int = 0,
        succeeded: bool = False,
        fix_applied: Optional[str] = None,
        cost_usd: float = 0.0,
        latency_s: float = 0.0,
        confidence: float = 0.0,
        model: str = "",
        selection_basis: str = "",
        source: str = "engine",
        notes: str = "",
        event_kind: str = "strategy_retry",
        granularity: str = "",
        target_id: str = "",
        stage_from: str = "",
        stage_to: str = "",
        stages_run: Optional[list[str]] = None,
        paid_requested: bool = False,
        paid_approved: bool = False,
        before: Optional[dict] = None,
        after: Optional[dict] = None,
        skipped_locked: Optional[list[dict]] = None,
        outcome: str = "",
        artifact_links: Optional[dict] = None,
        linear_issue_ids: Optional[list[str]] = None,
    ) -> None:
        """Buffer one retry event for transient-vs-persistent analysis.

        Every argument is stored under a key of the same name, together with
        the current time (``timestamp``) and ``self.project`` (``project``).
        List and dict arguments that are ``None`` (or otherwise falsy) are
        stored as empty containers. Every parameter has a default, so callers
        may pass only the fields they know.

        The record stays in memory until ``flush()`` writes it to
        retries.jsonl.
        """
        record = dict(
            timestamp=time.time(),
            project=self.project,
            event_kind=event_kind,
            shot_id=shot_id,
            pass_id=pass_id,
            failure_category=failure_category,
            failure_mode=failure_mode,
            strategy_applied=strategy_applied,
            prior_strategies=prior_strategies if prior_strategies else [],
            retry_number=retry_number,
            succeeded=succeeded,
            fix_applied=fix_applied,
            cost_usd=cost_usd,
            latency_s=latency_s,
            confidence=confidence,
            model=model,
            selection_basis=selection_basis,
            source=source,
            notes=notes,
            granularity=granularity,
            target_id=target_id,
            stage_from=stage_from,
            stage_to=stage_to,
            stages_run=stages_run if stages_run else [],
            paid_requested=paid_requested,
            paid_approved=paid_approved,
            before=before if before else {},
            after=after if after else {},
            skipped_locked=skipped_locked if skipped_locked else [],
            outcome=outcome,
            artifact_links=artifact_links if artifact_links else {},
            linear_issue_ids=linear_issue_ids if linear_issue_ids else [],
        )
        self._retries.append(record)

    def query_strategy_stats(
        self,
        failure_mode: str,
        model: str = "",
        min_samples: int = 1,
    ) -> dict[str, "StrategyStats"]:
        """Summarize past retries per strategy for a (failure_mode, model) pair.

        Only records already written to retries.jsonl are considered; entries
        still buffered by ``ingest_retry`` are not read.

        Args:
            failure_mode: Records must carry exactly this ``failure_mode``.
            model: When non-empty, records must also carry this ``model``.
            min_samples: Strategies with fewer matching records are omitted.

        Returns:
            ``{strategy_name: StrategyStats}`` with success rate, average
            cost, average latency and most recent timestamp per strategy.
        """
        # Select the matching records and bucket them by strategy in one pass.
        by_strategy: dict[str, list[dict]] = defaultdict(list)
        for rec in self._load_jsonl("retries.jsonl"):
            # Records without an event_kind count as strategy retries.
            if rec.get("event_kind", "strategy_retry") != "strategy_retry":
                continue
            if rec.get("failure_mode") != failure_mode:
                continue
            if model and rec.get("model") != model:
                continue
            name = rec.get("strategy_applied", "")
            if name:
                by_strategy[name].append(rec)

        summary: dict[str, StrategyStats] = {}
        for name, recs in by_strategy.items():
            count = len(recs)  # always >= 1: buckets only exist once appended to
            if count < min_samples:
                continue
            wins = sum(bool(r.get("succeeded")) for r in recs)
            cost_sum = sum(read_cost_from_record_safe(r) for r in recs)
            latency_sum = sum(r.get("latency_s", 0.0) for r in recs)
            newest = max(r.get("timestamp", 0.0) for r in recs)
            summary[name] = StrategyStats(
                strategy_name=name,
                success_rate=wins / count,
                n_samples=count,
                avg_cost_usd=cost_sum / count,
                avg_latency_s=latency_sum / count,
                last_used_ts=newest,
            )

        return summary

    def recommend_strategy(
        self,
        failure_mode: str,
        model: str,
        already_tried: list[str],
    ) -> Optional[str]:
        """Pick the historically best untried strategy, if there is enough data.

        Args:
            failure_mode: Failure mode being retried.
            model: Model the failure occurred on.
            already_tried: Strategy names that must not be suggested again.

        Returns:
            The name of the untried strategy with the highest success rate,
            or None when

            - the total sample count for this (failure_mode, model) across
              all strategies is below AUTONOMY_THRESHOLD (the caller is
              expected to fall back to its static chain), or
            - every strategy is either already tried or quarantined
              (success_rate == 0.0 with at least 5 samples).
        """
        per_strategy = self.query_strategy_stats(failure_mode, model=model)

        # Too little evidence overall: do not override the static ordering.
        if sum(s.n_samples for s in per_strategy.values()) < AUTONOMY_THRESHOLD:
            return None

        # Drop what was already attempted and what is quarantined; a strategy
        # stays eligible if it ever succeeded or has fewer than 5 samples.
        eligible = [
            s
            for name, s in per_strategy.items()
            if name not in already_tried
            and (s.success_rate != 0.0 or s.n_samples < 5)
        ]

        # max() returns the first maximal item, which is the same element a
        # stable descending sort would place at index 0.
        best = max(eligible, key=lambda s: s.success_rate, default=None)
        if best is None:
            return None
        return best.strategy_name

    def write_seed_data(self) -> None:
        """Append the built-in seed record to retries.jsonl, at most once.

        The record is a hard-coded retry outcome (empirical data from
        pipeline-learnings.md) marked with ``"_seed": True``. If any record
        in retries.jsonl already carries a truthy ``_seed`` flag, nothing is
        written, which makes the call idempotent.
        """
        already_seeded = any(
            rec.get("_seed") for rec in self._load_jsonl("retries.jsonl")
        )
        if already_seeded:
            logger.info("Seed data already present in retries.jsonl — skipping.")
            return

        seed = {
            "timestamp": 1713168000.0,
            "project": "tartarus",
            "pass_id": "EP001_PASS_002_TEST_480P_FAL",
            "model": "seeddance-2.0",
            "failure_mode": "identity_drift",
            "strategy_applied": "add_turnaround_angles",
            "prior_strategies": [],
            "retry_number": 1,
            "succeeded": True,
            "cost_usd": 1.67,
            "latency_s": 237.0,
            "confidence": 1.0,
            "selection_basis": "static_chain",
            "source": "seed",
            "notes": "Empirical verification of _gather_identity_refs hero+turnaround fix (pipeline-learnings \u00a710g). JT confirmed identity held.",
            # Marker used by the idempotency check above.
            "_seed": True,
        }
        # Written straight to disk, bypassing the in-memory retry buffer.
        self._append_jsonl("retries.jsonl", [seed])
        logger.info("Wrote seed data to retries.jsonl.")

    def flush(self) -> None:
        """Append every non-empty in-memory buffer to its JSONL file.

        Verdicts go to verdicts.jsonl, overrides to overrides.jsonl and
        retries to retries.jsonl, in that order. Each buffer is emptied only
        after its append has returned; empty buffers touch no file.
        """
        pending = (
            ("verdicts.jsonl", self._verdicts),
            ("overrides.jsonl", self._overrides),
            ("retries.jsonl", self._retries),
        )
        for filename, buffer in pending:
            if not buffer:
                continue
            self._append_jsonl(filename, buffer)
            buffer.clear()

    def generate_report(self) -> dict:
        """Summarize flushed verdicts and overrides into a pattern report.

        Reads verdicts.jsonl and overrides.jsonl from the state directory;
        records still buffered in memory (not yet ``flush()``-ed) are not
        included. When verdicts exist, the report is also written to
        ``latest_report.json`` in the state directory.

        Returns:
            ``{"status": "no_data", "message": ...}`` when no verdicts are on
            disk (no report file is written in that case). Otherwise a dict
            with ``generated_at``, ``project``, ``total_verdicts``,
            ``total_overrides`` and:

            - ``by_model``: total, pass_rate and total_cost per model
            - ``by_shot_type``: total and pass_rate per shot type
            - ``gate_calibration``: false_positive, false_negative and
              total_overrides counts per gate
        """
        verdicts = self._load_jsonl("verdicts.jsonl")
        overrides = self._load_jsonl("overrides.jsonl")

        if not verdicts:
            return {"status": "no_data", "message": "No verdicts to analyze yet."}

        # Tally per model and per shot type in a single pass.
        model_totals = defaultdict(lambda: {"pass": 0, "total": 0, "cost": 0.0})
        type_totals = defaultdict(lambda: {"pass": 0, "total": 0})
        for verdict in verdicts:
            passed = 1 if verdict.get("success") else 0

            per_model = model_totals[verdict.get("model", "unknown")]
            per_model["total"] += 1
            per_model["pass"] += passed
            # A cost recorded as None is stored as JSON null; count it as
            # zero instead of failing the whole report with a TypeError.
            per_model["cost"] += verdict.get("cost") or 0

            per_type = type_totals[verdict.get("shot_type", "unknown")]
            per_type["total"] += 1
            per_type["pass"] += passed

        # Every bucket holds at least one verdict, so "total" is never zero.
        model_report = {
            name: {
                "total": tally["total"],
                "pass_rate": round(tally["pass"] / tally["total"], 3),
                "total_cost": round(tally["cost"], 4),
            }
            for name, tally in model_totals.items()
        }
        type_report = {
            name: {
                "total": tally["total"],
                "pass_rate": round(tally["pass"] / tally["total"], 3),
            }
            for name, tally in type_totals.items()
        }

        # Gate calibration: how often humans overrode each gate, split into
        # false positives and false negatives. Other override types only
        # count toward the total.
        gate_cal = defaultdict(
            lambda: {"false_positive": 0, "false_negative": 0, "total_overrides": 0}
        )
        for override in overrides:
            bucket = gate_cal[override.get("gate_name", "unknown")]
            bucket["total_overrides"] += 1
            kind = override.get("override_type")
            if kind in ("false_positive", "false_negative"):
                bucket[kind] += 1

        report = {
            "generated_at": time.time(),
            "project": self.project,
            "total_verdicts": len(verdicts),
            "total_overrides": len(overrides),
            "by_model": model_report,
            "by_shot_type": type_report,
            "gate_calibration": dict(gate_cal),
        }

        # Persist the latest report next to the raw learning data.
        self._write_json(self._state_dir / "latest_report.json", report)

        return report

    def update_shared_learnings(self, findings: list[str]) -> None:
        """Append novel findings to the shared pipeline-learnings-auto.md.

        The target file (``_SHARED_LEARNINGS_PATH``) is synced via Dropbox
        and visible across all machines. Its parent directory is created if
        missing.

        A finding is skipped when its text already occurs anywhere in the
        file (substring match) or in a finding accepted earlier in the same
        call, so passing the same finding twice in one batch writes it once.
        Accepted findings are appended as a bullet list under a heading that
        carries the project name and the local time.

        Args:
            findings: Finding texts to record; an empty list is a no-op.
        """
        if not findings:
            return

        _SHARED_LEARNINGS_PATH.parent.mkdir(parents=True, exist_ok=True)

        existing = ""
        if _SHARED_LEARNINGS_PATH.exists():
            existing = _SHARED_LEARNINGS_PATH.read_text(encoding="utf-8")

        # Filter out duplicates, both against the file and within this
        # batch, so the result matches appending the findings one by one.
        new_findings: list[str] = []
        for finding in findings:
            if finding in existing:
                continue
            if any(finding in accepted for accepted in new_findings):
                continue
            new_findings.append(finding)
        if not new_findings:
            return

        timestamp = time.strftime("%Y-%m-%d %H:%M")
        header = f"\n\n## Auto-learned: {self.project} ({timestamp})\n"
        block = header + "".join(f"- {finding}\n" for finding in new_findings)

        with open(_SHARED_LEARNINGS_PATH, "a", encoding="utf-8") as fh:
            fh.write(block)

        logger.info("Appended %d findings to shared learnings", len(new_findings))

    def _append_jsonl(self, filename: str, entries: list[dict]) -> None:
        """Append ``entries`` to ``filename`` in the state directory.

        Each entry becomes one JSON line; values the JSON encoder cannot
        handle natively are converted with ``str``. The file is created if
        it does not exist yet.
        """
        target = self._state_dir / filename
        with open(target, "a", encoding="utf-8") as out:
            out.writelines(
                json.dumps(item, default=str) + "\n" for item in entries
            )

    def _load_jsonl(self, filename: str) -> list[dict]:
        """Load every JSON object stored in a JSONL file of the state directory.

        Loading is best-effort: blank lines are ignored, and lines that are
        not valid JSON or do not hold a JSON object are skipped (a single
        warning reports how many), so one bad line never hides the rest.

        Args:
            filename: File name relative to the state directory.

        Returns:
            The decoded objects in file order; an empty list when the file
            does not exist.
        """
        path = self._state_dir / filename
        if not path.exists():
            return []

        entries: list[dict] = []
        skipped = 0
        # Iterate physical lines rather than str.splitlines(): the latter
        # also splits on characters such as U+2028, which may appear
        # unescaped inside a JSON string and would cut a valid record in two.
        with open(path, encoding="utf-8") as fh:
            for raw in fh:
                text = raw.strip()
                if not text:
                    continue
                try:
                    value = json.loads(text)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                # Callers use dict methods on every entry, so a stray scalar,
                # list or null line must not be handed back.
                if isinstance(value, dict):
                    entries.append(value)
                else:
                    skipped += 1

        if skipped:
            logging.getLogger(__name__).warning(
                "Skipped %d malformed line(s) in %s", skipped, path
            )
        return entries

    def _write_json(self, path: Path, data: dict) -> None:
        """Write ``data`` to ``path`` as indented JSON via a temporary file.

        The content is first written to a ``.learn_*.tmp`` file created in
        the destination directory and then moved over ``path`` with
        ``os.replace``. If serialization or the move raises an ``Exception``,
        the temporary file is removed (best effort) and the error is
        re-raised. Values the JSON encoder cannot handle natively are
        converted with ``str``.
        """
        handle, tmp_name = tempfile.mkstemp(
            prefix=".learn_", suffix=".tmp", dir=str(path.parent)
        )
        try:
            with os.fdopen(handle, "w", encoding="utf-8") as out:
                out.write(json.dumps(data, indent=2, default=str))
            os.replace(tmp_name, str(path))
        except Exception:
            # Do not leave the temp file behind; a failure to remove it must
            # not mask the original error.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise
