#!/usr/bin/env python3
"""feedback_report.py — Post-series failure analysis.

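Reads ExecutionStore shot data and the feedback JSONL log for a project.
Aggregates failures by gate, model, pipeline, feedback strategy, and episode,
and tracks total cost, retry waste cost, and CROP_TO_CLOSEUP usage.
Outputs a markdown report for human review, or raw JSON with --json.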

Usage:
    python3 tools/feedback_report.py --project tartarus
    python3 tools/feedback_report.py --project tartarus --episode 1
    python3 tools/feedback_report.py --project tartarus --output report.md
    python3 tools/feedback_report.py --project tartarus --json
"""

import argparse
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path

# Resolve project roots from this file's location:
# tool dir -> its parent (pipeline root) -> its parent (recoil root).
_TOOL_DIR = Path(__file__).parent
_PIPELINE_ROOT = _TOOL_DIR.parent
_RECOIL_ROOT = _PIPELINE_ROOT.parent
# Put both roots on sys.path so the `recoil` imports below resolve when this
# tool is run directly. Both are inserted at index 0, so _PIPELINE_ROOT
# (inserted last) ends up ahead of _RECOIL_ROOT in the search order.
sys.path.insert(0, str(_RECOIL_ROOT))
sys.path.insert(0, str(_PIPELINE_ROOT))

# NOTE(review): projects_root is not referenced anywhere else in this file.
from recoil.core.paths import projects_root, ProjectPaths

# Threshold for the CROP_TO_CLOSEUP creep warning, expressed as a fraction
# (format_report multiplies it by 100 before comparing against a percentage).
# Falls back to 0.05 (5%) when the feedback constants module cannot be imported.
try:
    from recoil.execution.feedback.constants import CROP_CLOSEUP_CREEP_THRESHOLD
except ImportError:
    CROP_CLOSEUP_CREEP_THRESHOLD = 0.05


def load_shots(project: str, episode: int | None = None) -> list[dict]:
    """Load all shot records from ExecutionStore JSON files.

    Reads every ``*.json`` file in the project's shots directory, in sorted
    filename order, and returns the decoded records.

    Args:
        project: Project name passed to ``ProjectPaths.for_project``.
        episode: When given, keep only shots whose ``episode_id`` yields this
            number. The number is formed by concatenating every decimal digit
            in ``episode_id``; an id with no digits, or a missing/null id,
            maps to 0.

    Returns:
        The decoded shot dicts. A missing shots directory yields ``[]`` (with
        a message on stderr). Files that cannot be read, are not valid UTF-8
        JSON, or do not hold a JSON object are skipped with a warning on stderr.
    """
    shots_dir = ProjectPaths.for_project(project).shots_dir
    if not shots_dir.exists():
        print(f"No shots directory found: {shots_dir}", file=sys.stderr)
        return []

    shots = []
    for f in sorted(shots_dir.glob("*.json")):
        try:
            data = json.loads(f.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"Skipping unreadable shot file {f}: {exc}", file=sys.stderr)
            continue
        if not isinstance(data, dict):
            print(f"Skipping shot file {f}: expected a JSON object", file=sys.stderr)
            continue

        if episode is not None:
            # A null or non-string episode_id must not crash the digit scan.
            ep_id = str(data.get("episode_id") or "")
            # isdecimal() only admits characters int() accepts, so this cannot raise.
            ep_num = int("".join(c for c in ep_id if c.isdecimal()) or "0")
            if ep_num != episode:
                continue
        shots.append(data)
    return shots


def load_feedback_log(project: str, log_path: "str | Path | None" = None) -> list[dict]:
    """Load feedback JSONL log entries for a project.

    Args:
        project: Only entries whose ``"project"`` field equals this are kept.
        log_path: JSONL file to read. Defaults to
            ``<recoil root>/engine-memory/feedback/feedback_log.jsonl``.

    Returns:
        Matching entries in file order. A missing log file yields ``[]``.
        Blank lines, malformed JSON lines, and lines whose JSON value is not
        an object are skipped.
    """
    if log_path is None:
        log_path = _RECOIL_ROOT / "engine-memory" / "feedback" / "feedback_log.jsonl"
    log_path = Path(log_path)
    if not log_path.exists():
        return []

    entries = []
    # Stream the log line by line rather than loading the whole file at once.
    # errors="replace" keeps one bad byte from aborting the entire report.
    with log_path.open(encoding="utf-8", errors="replace") as fh:
        for raw in fh:
            line = raw.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A valid JSON line that is not an object (e.g. a list) has no .get().
            if isinstance(entry, dict) and entry.get("project") == project:
                entries.append(entry)
    return entries


def _is_failed(shot: dict) -> bool:
    """Return True when the shot's status contains the substring "failed".

    A missing or null status counts as not failed.
    """
    return "failed" in (shot.get("status") or "")


def analyze(shots: list[dict], feedback_entries: list[dict]) -> dict:
    """Analyze shots and feedback data, return aggregated stats.

    Args:
        shots: Shot record dicts (as produced by ``load_shots``).
        feedback_entries: Feedback log entry dicts (as produced by
            ``load_feedback_log``).

    Returns:
        ``{"total": 0}`` when there are no shots. Otherwise a dict with keys
        ``total``, ``status_counts``, ``failed_count``, ``deferred_count``,
        ``gate_failures``, ``failure_reasons``, ``model_stats``,
        ``pipeline_stats``, ``strategy_stats``, ``total_cost``, ``waste_cost``,
        ``episode_stats``, ``crop_closeup_count`` and ``crop_closeup_pct``
        (a percentage in 0-100 terms).

    Missing, null or empty grouping fields (status, model, pipeline,
    episode_id, gate_name, strategy) are bucketed under ``"unknown"`` so every
    grouping key is a string and can be sorted by the report. Missing or null
    costs count as 0, and a missing ``shot_id`` is reported as ``"unknown"``.
    """
    total = len(shots)
    if total == 0:
        return {"total": 0}

    status_counts = Counter()
    failed_count = 0
    deferred_count = 0
    total_cost = 0
    waste_cost = 0
    gate_failures = Counter()
    failure_reasons = defaultdict(list)
    model_stats = defaultdict(lambda: {"total": 0, "failed": 0, "cost": 0.0})
    pipeline_stats = defaultdict(lambda: {"total": 0, "failed": 0})
    episode_stats = defaultdict(lambda: {"total": 0, "failed": 0, "deferred": 0})

    # Single pass over the shots for every per-shot aggregation.
    for s in shots:
        failed = int(_is_failed(s))
        deferred = int(bool(s.get("deferred")))
        cost = s.get("cost_incurred") or 0

        status_counts[s.get("status") or "unknown"] += 1
        failed_count += failed
        deferred_count += deferred
        total_cost += cost
        waste_cost += s.get("retry_waste_cost") or 0

        # Gate failures are counted per take, not per shot; a take without a
        # verdict, or whose verdict lacks "passed", is not counted.
        for take in s.get("takes") or []:
            gv = take.get("gate_verdict") or {}
            if gv and not gv.get("passed", True):
                gate_name = gv.get("gate_name") or "unknown"
                gate_failures[gate_name] += 1
                failure_reasons[gate_name].append({
                    "shot_id": s.get("shot_id", "unknown"),
                    "reason": str(gv.get("reason") or "")[:100],
                    "model": take.get("model", ""),
                })

        model_entry = model_stats[s.get("model") or "unknown"]
        model_entry["total"] += 1
        model_entry["failed"] += failed
        model_entry["cost"] += cost

        pipeline_entry = pipeline_stats[s.get("pipeline") or "unknown"]
        pipeline_entry["total"] += 1
        pipeline_entry["failed"] += failed

        episode_entry = episode_stats[s.get("episode_id") or "unknown"]
        episode_entry["total"] += 1
        episode_entry["failed"] += failed
        episode_entry["deferred"] += deferred

    # Feedback strategy effectiveness, plus CROP_TO_CLOSEUP usage in one pass.
    strategy_stats = defaultdict(lambda: {"used": 0, "passed": 0, "failed_same": 0, "failed_different": 0})
    crop_count = 0
    for entry in feedback_entries:
        strategy = entry.get("strategy")
        strat_entry = strategy_stats[strategy or "unknown"]
        strat_entry["used"] += 1
        result = entry.get("result", "")
        if result == "passed":
            strat_entry["passed"] += 1
        elif result == "failed_same":
            strat_entry["failed_same"] += 1
        elif result in ("failed_different", "failed_worse"):
            # "failed_worse" is folded into the different-failure bucket.
            strat_entry["failed_different"] += 1
        if strategy == "crop_to_closeup":
            crop_count += 1

    # NOTE(review): main() filters shots by episode but feedback only by
    # project, so with --episode this ratio divides project-wide crop usage
    # by an episode-scoped shot count.
    crop_pct = crop_count / total * 100

    return {
        "total": total,
        "status_counts": dict(status_counts),
        "failed_count": failed_count,
        "deferred_count": deferred_count,
        "gate_failures": dict(gate_failures),
        "failure_reasons": dict(failure_reasons),
        "model_stats": dict(model_stats),
        "pipeline_stats": dict(pipeline_stats),
        "strategy_stats": dict(strategy_stats),
        "total_cost": total_cost,
        "waste_cost": waste_cost,
        "episode_stats": dict(episode_stats),
        "crop_closeup_count": crop_count,
        "crop_closeup_pct": crop_pct,
    }


def format_report(project: str, stats: dict, episode: int | None = None) -> str:
    """Format analysis stats into a markdown report."""
    lines = []
    ep_label = f" Episode {episode}" if episode else ""
    lines.append(f"# Feedback Report — {project}{ep_label}")
    lines.append("")

    if stats["total"] == 0:
        lines.append("No shots found.")
        return "\n".join(lines)

    # Summary
    lines.append("## Summary")
    lines.append(f"- **Total shots:** {stats['total']}")
    lines.append(f"- **Failed:** {stats['failed_count']} ({stats['failed_count']/stats['total']*100:.1f}%)")
    lines.append(f"- **Deferred:** {stats['deferred_count']}")
    lines.append(f"- **Total cost:** ${stats['total_cost']:.2f}")
    lines.append(f"- **Waste cost (retries):** ${stats['waste_cost']:.2f}")
    lines.append("")

    # Status distribution
    lines.append("## Status Distribution")
    for status, count in sorted(stats["status_counts"].items(), key=lambda x: -x[1]):
        lines.append(f"- {status}: {count}")
    lines.append("")

    # Gate failures
    if stats["gate_failures"]:
        lines.append("## Gate Failures")
        for gate, count in sorted(stats["gate_failures"].items(), key=lambda x: -x[1]):
            lines.append(f"### {gate} — {count} failures")
            reasons = stats["failure_reasons"].get(gate, [])
            reason_counts = Counter(r["reason"] for r in reasons)
            for reason, rc in reason_counts.most_common(5):
                lines.append(f"- ({rc}x) {reason}")
            lines.append("")

    # Model performance
    if stats["model_stats"]:
        lines.append("## Model Performance")
        lines.append("| Model | Shots | Failed | Fail % | Cost |")
        lines.append("|-------|-------|--------|--------|------|")
        for model, ms in sorted(stats["model_stats"].items()):
            fail_pct = ms["failed"] / ms["total"] * 100 if ms["total"] > 0 else 0
            lines.append(f"| {model} | {ms['total']} | {ms['failed']} | {fail_pct:.1f}% | ${ms['cost']:.2f} |")
        lines.append("")

    # Pipeline performance
    if stats["pipeline_stats"]:
        lines.append("## Pipeline Performance")
        for pipeline, ps in sorted(stats["pipeline_stats"].items()):
            fail_pct = ps["failed"] / ps["total"] * 100 if ps["total"] > 0 else 0
            lines.append(f"- **{pipeline}:** {ps['total']} shots, {ps['failed']} failed ({fail_pct:.1f}%)")
        lines.append("")

    # Feedback effectiveness
    if stats["strategy_stats"]:
        lines.append("## Feedback Strategy Effectiveness")
        lines.append("| Strategy | Used | Passed | Same Fail | Different Fail | Success % |")
        lines.append("|----------|------|--------|-----------|----------------|-----------|")
        for strat, ss in sorted(stats["strategy_stats"].items()):
            success_pct = ss["passed"] / ss["used"] * 100 if ss["used"] > 0 else 0
            lines.append(
                f"| {strat} | {ss['used']} | {ss['passed']} | "
                f"{ss['failed_same']} | {ss['failed_different']} | {success_pct:.0f}% |"
            )
        lines.append("")

    # CROP_TO_CLOSEUP creep
    if stats["crop_closeup_count"] > 0:
        lines.append("## CROP_TO_CLOSEUP Creep")
        threshold_pct = CROP_CLOSEUP_CREEP_THRESHOLD * 100
        warning = f" **ABOVE {threshold_pct:.0f}% THRESHOLD — upstream problem**" if stats["crop_closeup_pct"] > threshold_pct else ""
        lines.append(f"- Crop count: {stats['crop_closeup_count']} ({stats['crop_closeup_pct']:.1f}%){warning}")
        lines.append("")

    # Episode breakdown
    if stats["episode_stats"]:
        lines.append("## Episode Breakdown")
        lines.append("| Episode | Shots | Failed | Deferred |")
        lines.append("|---------|-------|--------|----------|")
        for ep, es in sorted(stats["episode_stats"].items()):
            lines.append(f"| {ep} | {es['total']} | {es['failed']} | {es['deferred']} |")
        lines.append("")

    return "\n".join(lines)


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: analyze a project and emit the report.

    Loads shots (optionally filtered to one episode) and the project's
    feedback log, aggregates them, and prints the markdown report (or raw
    JSON with ``--json``) to stdout or writes it to ``--output``.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Post-series feedback analysis")
    parser.add_argument("--project", required=True, help="Project name (e.g. tartarus)")
    parser.add_argument("--episode", type=int, default=None, help="Filter to specific episode number")
    parser.add_argument("--output", type=str, default=None, help="Output file path (default: stdout)")
    parser.add_argument("--json", action="store_true", help="Output raw JSON instead of markdown")
    args = parser.parse_args(argv)

    shots = load_shots(args.project, args.episode)
    feedback = load_feedback_log(args.project)
    stats = analyze(shots, feedback)

    if args.json:
        output = json.dumps(stats, indent=2)
    else:
        output = format_report(args.project, stats, args.episode)

    if args.output:
        # Explicit UTF-8: the report contains non-ASCII text (the em dash in
        # the title) and the platform default encoding may not be UTF-8.
        # The trailing newline matches what print() emits for stdout.
        Path(args.output).write_text(output + "\n", encoding="utf-8")
        print(f"Report written to {args.output}")
    else:
        print(output)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
