#!/usr/bin/env python3
"""Status ledger helper for dispatch runs.

This module is intentionally stdlib-only. It owns the status.json lock and
atomic write protocol so shell actors cannot accidentally lose updates.
"""

from __future__ import annotations

import argparse
import hashlib
import json
import os
from pathlib import Path
import re
import subprocess
import sys
import tempfile
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Callable


# Max concurrent non-terminal dispatch runs. SSOT for the cap — the /dispatch
# skill's concurrency gate reads this (do NOT hardcode the number in the skill).
# Raised 2→4 (2026-06-07) now that Codex is on the Pro mid-tier (5× usage).
CONCURRENT_RUNS_MAX = 4

# Budget template copied into every new status.json by cmd_init. The *_used
# counters start at zero; codex_rounds_max is replaced by codex_rounds_ceiling()
# when init is given --phases. 21600 s is six hours.
DEFAULT_BUDGET = {
    "wall_clock_s_max": 21600,
    "wall_clock_s_used": 0,
    "codex_rounds_max": 18,
    "codex_rounds_used": 0,
    "concurrent_runs_max": CONCURRENT_RUNS_MAX,
}

# cause_hint values that cmd_classify reports as DETERMINISTIC without needing
# a repeated failure signature.
DETERMINISTIC_CAUSES = {"frozen_gate", "content_policy", "spec_structural"}

# Terminal dispatch run states (run reaped/converged or capped). MODULE-LEVEL
# SSOT — the single importable Python home for the terminal set, imported by
# recoil_dashboard.py (REC-229). The bash dispatch_reaper.sh:41 string is a
# separate language home kept coherent by test_recoil_dashboard.py's drift guard.
TERMINAL_STATES = {"CONVERGED_PR_CREATED", "CAPPED_NEEDS_HUMAN"}
# Absolute path of the shlock binary that RunLock invokes to take status.lock.
SHLOCK = "/usr/bin/shlock"
# Seconds RunLock keeps retrying before raising DispatchStatusError.
LOCK_TIMEOUT_S = 30.0
# Seconds RunLock sleeps between two shlock attempts.
LOCK_POLL_S = 0.1


class DispatchStatusError(Exception):
    """Error raised by this module for ledger problems it reports itself.

    ``main`` catches it, prints the message to stderr and exits with status 1.
    """


def utc_now() -> datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    current = datetime.now(tz=timezone.utc)
    return current


def iso_now() -> str:
    """Return the current UTC time as a second-precision ISO-8601 string.

    The UTC offset is rendered as a trailing ``Z``.
    """
    # Same truncate-and-format steps as iso_from_dt, applied to "now".
    return iso_from_dt(utc_now())


def iso_from_dt(value: datetime) -> str:
    """Format *value* as ISO-8601 with microseconds dropped.

    A ``+00:00`` offset is rewritten as ``Z``; any other offset (or a naive
    datetime's missing offset) is left exactly as ``isoformat`` renders it.
    """
    whole_seconds = value.replace(microsecond=0)
    rendered = whole_seconds.isoformat()
    return rendered.replace("+00:00", "Z")


def run_dir_path(raw: str) -> Path:
    """Turn a run-directory argument into an absolute, resolved ``Path``.

    ``~`` is expanded first, then the path is resolved.
    """
    expanded = Path(raw).expanduser()
    return expanded.resolve()


def status_path(run_dir: Path) -> Path:
    """Return the location of the status ledger inside *run_dir*."""
    ledger_name = "status.json"
    return run_dir / ledger_name


def events_path(run_dir: Path) -> Path:
    """Return the location of the JSON-lines event log inside *run_dir*."""
    log_name = "events.jsonl"
    return run_dir / log_name


class RunLock:
    """Context manager holding ``<run_dir>/status.lock`` via the shlock binary.

    Entering runs ``SHLOCK -f <lock> -p <pid>`` repeatedly until it exits 0 or
    ``LOCK_TIMEOUT_S`` elapses. Leaving removes the lock file, but only when it
    still records this process's pid.

    Raises:
        DispatchStatusError: if the lock cannot be taken before the deadline,
            or if the shlock binary cannot be executed at all.
    """

    def __init__(self, run_dir: Path) -> None:
        self.run_dir = run_dir
        self.lock_path = run_dir / "status.lock"
        # True only while this instance holds the lock.
        self.acquired = False

    def __enter__(self) -> "RunLock":
        self.run_dir.mkdir(parents=True, exist_ok=True)
        command = [SHLOCK, "-f", str(self.lock_path), "-p", str(os.getpid())]
        deadline = time.monotonic() + LOCK_TIMEOUT_S
        while True:
            try:
                result = subprocess.run(
                    command,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    check=False,
                )
            except OSError as exc:
                # A missing or non-executable binary never becomes available by
                # polling; report it as a ledger error so the CLI prints a
                # message instead of a traceback.
                raise DispatchStatusError(f"cannot run {SHLOCK}: {exc}") from exc
            if result.returncode == 0:
                self.acquired = True
                return self
            if time.monotonic() >= deadline:
                raise DispatchStatusError(f"timed out acquiring {self.lock_path}")
            time.sleep(LOCK_POLL_S)

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        if not self.acquired:
            return
        # Drop the flag first so it never claims a lock this instance no
        # longer holds, whatever happens to the file below.
        self.acquired = False
        try:
            current = self.lock_path.read_text(encoding="utf-8").strip()
        except FileNotFoundError:
            return
        # Only remove the file when it still names this process; otherwise it
        # belongs to someone else and must be left alone.
        if current == str(os.getpid()):
            try:
                self.lock_path.unlink()
            except FileNotFoundError:
                pass


def read_json(path: Path) -> dict[str, Any]:
    """Load *path* as UTF-8 JSON and return it, insisting on a JSON object.

    Raises:
        DispatchStatusError: if the top-level JSON value is not an object.
    """
    with path.open("r", encoding="utf-8") as source:
        parsed = json.load(source)
    if isinstance(parsed, dict):
        return parsed
    raise DispatchStatusError(f"{path} must contain a JSON object")


def write_json_atomic(path: Path, data: dict[str, Any]) -> None:
    """Write *data* to *path* as indented, key-sorted JSON via a temp file.

    The JSON is written to a temporary file in the same directory, flushed and
    fsynced, and then moved over *path* with ``os.replace``. The temporary
    file is removed if anything fails before the replace.
    """
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    fd, scratch_name = tempfile.mkstemp(
        dir=str(folder), prefix=f".{path.name}.", suffix=".tmp"
    )
    scratch = Path(scratch_name)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as out:
            out.write(json.dumps(data, indent=2, sort_keys=True))
            out.write("\n")
            out.flush()
            # Push the bytes to disk before the rename makes them visible.
            os.fsync(out.fileno())
        os.replace(scratch, path)
    finally:
        # After a successful replace the temp name is gone; this only fires
        # on the failure paths.
        if scratch.exists():
            scratch.unlink()


def append_event(run_dir: Path, event: str, **fields: Any) -> None:
    """Append one compact JSON line describing *event* to the run's event log.

    The record carries ``at`` (current UTC timestamp) and ``event``; any
    keyword *fields* are merged on top, so a field named ``at`` overrides the
    generated timestamp.
    """
    record = {"at": iso_now(), "event": event, **fields}
    with events_path(run_dir).open("a", encoding="utf-8") as log:
        encoded = json.dumps(record, sort_keys=True, separators=(",", ":"))
        log.write(encoded + "\n")


def load_status(run_dir: Path) -> dict[str, Any]:
    """Read and return the status ledger stored in *run_dir*."""
    ledger = status_path(run_dir)
    return read_json(ledger)


def save_status(run_dir: Path, status: dict[str, Any], event: str | None = None, **event_fields: Any) -> None:
    """Persist *status* atomically and, when *event* is given, log it.

    With ``event=None`` only the ledger is written and *event_fields* are
    ignored.
    """
    write_json_atomic(status_path(run_dir), status)
    if event is None:
        return
    append_event(run_dir, event, **event_fields)


def mutate_status(
    run_dir: Path, mutator: Callable[[dict[str, Any]], tuple[dict[str, Any], str | None, dict[str, Any]]]
) -> int:
    """Run a read-modify-write cycle on the ledger while holding the run lock.

    *mutator* receives the loaded status and returns
    ``(new_status, event_name_or_None, event_fields)``. The new status is
    always written; an event line is appended only when an event name is
    returned. Always returns 0.
    """
    with RunLock(run_dir):
        current = load_status(run_dir)
        new_status, event_name, extra = mutator(current)
        if event_name is None:
            # No event to log: write the ledger only.
            write_json_atomic(status_path(run_dir), new_status)
        else:
            save_status(run_dir, new_status, event_name, **extra)
    return 0


def budget_exceeded(status: dict[str, Any]) -> bool:
    """Return True when a ``*_used`` counter is strictly above its ``*_max``.

    Both the wall-clock pair and the codex-rounds pair are checked. A missing
    counter counts as 0. A ``budget`` that is absent, ``null`` or not a JSON
    object is treated as an empty budget, and a ``null`` counter as 0, which
    matches how ``cmd_budget_check`` and ``cmd_retry_start`` read the field.
    """
    budget = status.get("budget")
    if not isinstance(budget, dict):
        return False

    def counter(key: str) -> int:
        # `or 0` covers both a missing key and an explicit null.
        return int(budget.get(key) or 0)

    return (
        counter("wall_clock_s_used") > counter("wall_clock_s_max")
        or counter("codex_rounds_used") > counter("codex_rounds_max")
    )


def codex_rounds_ceiling(phases: int) -> int:
    """Per-attempt Codex round ceiling, scaled to the number of phases.

    The ceiling is two rounds per phase (one round plus roughly one
    debug-retry of headroom), plus 6 for CONVERGE_MAX_ROUNDS, plus 4 of base
    slack. The flat default of 18 was tuned for ~4-phase builds and
    false-capped a 14-phase build. The counter resets per attempt (see
    cmd_retry_start), so this bounds a single attempt's Codex usage.
    """
    per_phase_rounds = 2 * int(phases)
    fixed_overhead = 10  # 6 converge rounds + 4 base slack
    return per_phase_rounds + fixed_overhead


def cmd_init(args: argparse.Namespace) -> int:
    """Create a fresh status.json and event log for a run (``init`` command).

    Under the run lock, writes the initial ledger in state ``STARTED`` at
    attempt 1 of 3, with a 120-second start grace window and a copy of
    ``DEFAULT_BUDGET``, then records an ``init`` event. When ``args.phases``
    is truthy the round ceiling is derived from it instead of the default.
    Returns 0.
    """
    run_dir = run_dir_path(args.run_dir)
    with RunLock(run_dir):
        created = utc_now()
        created_iso = iso_from_dt(created)

        budget = {**DEFAULT_BUDGET}
        phase_count = getattr(args, "phases", None)
        if phase_count:
            budget["codex_rounds_max"] = codex_rounds_ceiling(phase_count)

        worktree = Path(args.worktree).expanduser().resolve()
        spec = Path(args.spec).expanduser().resolve()
        grace_deadline = created + timedelta(seconds=120)

        status = {
            "run_id": run_dir.name,
            "issue": args.issue,
            "branch": args.branch,
            "worktree": str(worktree),
            "spec": str(spec),
            "state": "STARTED",
            "attempt": 1,
            "max_attempts": 3,
            "last_validated_commit": args.last_validated_commit,
            "created_at": created_iso,
            "updated_at": created_iso,
            "started_grace_until": iso_from_dt(grace_deadline),
            "budget": budget,
            "last_failure_signature": None,
            "prior_failure_signatures": [],
            "last_retry_cause": None,
            "pr_url": None,
            "linear_projection_dirty": True,
            "linear_projected_at": None,
        }
        write_json_atomic(status_path(run_dir), status)
        events_path(run_dir).touch()
        append_event(run_dir, "init", state="STARTED", attempt=1)
    return 0


def cmd_transition(args: argparse.Namespace) -> int:
    """Move the run to ``args.state`` and log a ``transition`` event.

    Always updates ``state``, ``updated_at`` and marks the Linear projection
    dirty. Optional CLI values (pr_url, failure signature, failure reason,
    last gate, last validation command) are stored only when non-empty. For
    ``CAPPED_NEEDS_HUMAN`` the failure reason, gate and validation command
    fall back to the attempt's ``terminal_status.json`` when not given on the
    command line. ``last_validated_commit`` is updated only on
    ``CONVERGED_PR_CREATED`` with a commit.
    """
    run_dir = run_dir_path(args.run_dir)

    def terminal_snapshot(status: dict[str, Any]) -> dict[str, Any]:
        # Only a capped run consults terminal_status.json for diagnostics.
        if args.state != "CAPPED_NEEDS_HUMAN":
            return {}
        attempt = int(status.get("attempt") or 1)
        preferred = run_dir / f"attempt-{attempt:03d}" / "terminal_status.json"
        try:
            return read_json(preferred)
        except Exception:
            # Best effort: fall back to the last attempt directory (in sorted
            # order) that has a terminal status file.
            candidates = sorted(run_dir.glob("attempt-*/terminal_status.json"))
        if not candidates:
            return {}
        try:
            return read_json(candidates[-1])
        except Exception:
            return {}

    def mutate(status: dict[str, Any]) -> tuple[dict[str, Any], str, dict[str, Any]]:
        stamp = iso_now()
        status["state"] = args.state
        status["updated_at"] = stamp
        status["linear_projection_dirty"] = True
        if args.pr_url:
            status["pr_url"] = args.pr_url
        if args.failure_signature:
            status["last_failure_signature"] = args.failure_signature

        snapshot = terminal_snapshot(status)
        diagnostics = {
            "failure_reason": args.failure_reason or snapshot.get("failure_reason"),
            "last_gate": args.last_gate or snapshot.get("gate"),
            "last_validation_command": (
                args.last_validation_command or snapshot.get("validation_command")
            ),
        }
        # Empty diagnostics are neither stored nor logged.
        present = {key: value for key, value in diagnostics.items() if value}
        status.update(present)

        if args.state == "CONVERGED_PR_CREATED" and args.commit:
            status["last_validated_commit"] = args.commit

        fields = {"state": args.state}
        optional = (
            ("pr_url", args.pr_url),
            ("commit", args.commit),
            ("failure_signature", args.failure_signature),
        )
        for key, value in optional:
            if value:
                fields[key] = value
        fields.update(present)
        return status, "transition", fields

    return mutate_status(run_dir, mutate)


def cmd_set_pr_url(args: argparse.Namespace) -> int:
    """Stamp pr_url at SOURCE (REC-229) — a pr_url-only write.

    Sets ``pr_url`` and ``updated_at`` and emits a ``pr_url_stamped`` event.
    It deliberately leaves ``state`` alone (so the reaper retains terminal
    ownership — R14) and leaves ``linear_projection_dirty`` alone (zero
    Linear-projection delta).
    """
    run_dir = run_dir_path(args.run_dir)
    url = args.pr_url

    def stamp(status: dict[str, Any]) -> tuple[dict[str, Any], str, dict[str, Any]]:
        status.update(pr_url=url, updated_at=iso_now())
        return status, "pr_url_stamped", {"pr_url": url}

    return mutate_status(run_dir, stamp)


def cmd_projection(args: argparse.Namespace) -> int:
    """Mark the Linear projection dirty or clean (``projection`` command).

    With ``--dirty`` the flag is set and a ``projection_dirty`` event is
    logged. Otherwise the flag is cleared, and ``linear_projected_at``
    stamped, only when the ledger's current state equals
    ``args.expected_state``; a mismatch leaves the fields untouched and logs
    nothing.
    """
    run_dir = run_dir_path(args.run_dir)

    def mutate(status: dict[str, Any]) -> tuple[dict[str, Any], str | None, dict[str, Any]]:
        if args.dirty:
            status["linear_projection_dirty"] = True
            status["updated_at"] = iso_now()
            return status, "projection_dirty", {"state": status.get("state")}
        if status.get("state") == args.expected_state:
            # One timestamp for both fields: two separate iso_now() calls can
            # straddle a second boundary and disagree.
            now = iso_now()
            status["linear_projection_dirty"] = False
            status["linear_projected_at"] = now
            status["updated_at"] = now
            return status, "projection_clean", {"state": status.get("state")}
        # State moved on since the projection was computed: keep it dirty.
        return status, None, {}

    return mutate_status(run_dir, mutate)


def cmd_retry_start(args: argparse.Namespace) -> int:
    """Advance the run to its next attempt (``retry-start`` command).

    Under the run lock: refuses when the next attempt would exceed
    ``max_attempts``; otherwise creates the ``attempt-NNN`` directory, sets
    state ``RETRY_PENDING``, zeroes the per-attempt budget counters, records
    the retry cause and failure signature, and logs a ``retry_start`` event.
    Returns 0 on success, 1 (with the message on stderr) on a
    DispatchStatusError.
    """
    run_dir = run_dir_path(args.run_dir)
    is_zombie = bool(args.zombie)
    # A zombie retry records no failure signature at all.
    recorded_signature = None if is_zombie else args.failure_signature
    try:
        with RunLock(run_dir):
            status = load_status(run_dir)
            upcoming = int(status.get("attempt", 0)) + 1
            ceiling = int(status.get("max_attempts", 0))
            if upcoming > ceiling:
                raise DispatchStatusError(
                    f"attempt cap exceeded: {upcoming} > {ceiling}"
                )
            (run_dir / f"attempt-{upcoming:03d}").mkdir(parents=True, exist_ok=True)
            status["attempt"] = upcoming
            status["state"] = "RETRY_PENDING"
            # Per-attempt reset: a retry is a fresh build from last_validated_commit,
            # so the round budget bounds THIS attempt's Codex usage, not cumulative
            # across attempts. Without this, a legit 2nd attempt false-caps on the
            # accumulated count (the 2026-06-06 REC-79 incident). MAX_RETRIES/phase
            # still bounds within-attempt churn.
            budget = status.get("budget")
            if isinstance(budget, dict):
                budget["codex_rounds_used"] = 0
                budget["wall_clock_s_used"] = 0
            status["updated_at"] = iso_now()
            status["linear_projection_dirty"] = True
            status["last_retry_cause"] = "zombie" if is_zombie else None
            if recorded_signature:
                history = list(status.get("prior_failure_signatures") or [])
                history.append(recorded_signature)
                status["prior_failure_signatures"] = history
                status["last_failure_signature"] = recorded_signature
            save_status(
                run_dir,
                status,
                "retry_start",
                attempt=upcoming,
                zombie=is_zombie,
                failure_signature=recorded_signature,
            )
    except DispatchStatusError as exc:
        print(str(exc), file=sys.stderr)
        return 1
    return 0


def normalize_signature_text(value: Any) -> str:
    """Replace run-specific noise in *value* with stable placeholders.

    ``None`` becomes the empty string; anything else is stringified. The
    rewrite rules run in the order listed, and the result is stripped of
    surrounding whitespace.
    """
    # (pattern, replacement, flags) — order matters: timestamps must be
    # collapsed before the duration and ":<digits>:" rules can see them.
    rules = (
        (r"\b\d{4}-\d{2}-\d{2}[T ][0-9:.]+(?:Z|[+-]\d{2}:?\d{2})?\b", "<timestamp>", 0),
        (r"\b\d+(?:\.\d+)?ms\b", "<duration>", 0),
        (r"\b\d+(?:\.\d+)?s\b", "<duration>", 0),
        (r"(?<!\S)/(?:tmp|var/folders)/[^\s\"']+", "<tmp-path>", 0),
        (r"\bpid\s+\d+\b", "pid <pid>", re.IGNORECASE),
        (r"\b0x[0-9a-fA-F]+\b", "<hex>", 0),
        (r"\b\d+\s+tokens?\b", "<tokens>", re.IGNORECASE),
        (r":\d+:", ":<line>:", 0),
    )
    cleaned = "" if value is None else str(value)
    for pattern, placeholder, flags in rules:
        cleaned = re.sub(pattern, placeholder, cleaned, flags=flags)
    return cleaned.strip()


def failure_signature_from_terminal(terminal_status: dict[str, Any]) -> str:
    """Derive a ``sha256:<hex>`` failure signature from a terminal status.

    The digest covers, newline-joined and in this order: the normalized
    phase, gate, validation command and converge status, the sorted failing
    test ids as compact JSON, and the normalized convergence verdict summary.
    """
    raw_ids = terminal_status.get("failing_test_ids") or []
    if not isinstance(raw_ids, list):
        # A non-list value is treated as one id.
        raw_ids = [str(raw_ids)]
    ids_blob = json.dumps(
        sorted(map(str, raw_ids)), ensure_ascii=True, separators=(",", ":")
    )

    leading_keys = ("phase", "gate", "validation_command", "converge_status")
    parts = [normalize_signature_text(terminal_status.get(key)) for key in leading_keys]
    parts.append(ids_blob)
    parts.append(
        normalize_signature_text(terminal_status.get("convergence_verdict_summary"))
    )

    material = "\n".join(parts).encode("utf-8")
    return "sha256:" + hashlib.sha256(material).hexdigest()


def cmd_signature(args: argparse.Namespace) -> int:
    """Print the failure signature computed from a terminal status file.

    Reads ``args.terminal_status`` (``~`` expanded) and returns 0.
    """
    source = Path(args.terminal_status).expanduser()
    signature = failure_signature_from_terminal(read_json(source))
    print(signature)
    return 0


def terminal_failure_signature(terminal_status: dict[str, Any]) -> str:
    """Return the signature for a terminal status.

    A ``failure_signature`` string already starting with ``sha256:`` is used
    as-is; anything else is ignored and the signature is recomputed from the
    other fields.
    """
    embedded = terminal_status.get("failure_signature")
    has_embedded = isinstance(embedded, str) and embedded.startswith("sha256:")
    if not has_embedded:
        return failure_signature_from_terminal(terminal_status)
    return embedded


def cmd_classify(args: argparse.Namespace) -> int:
    """Print a one-word classification of an attempt's outcome; return 0.

    Verdicts, first match wins:
      CONVERGED      converge_status is CONVERGED and exit_code is 0
      DETERMINISTIC  cause_hint is a known deterministic cause, or the
                     failure signature was already seen in a prior attempt
      CAPPED_BUDGET  cause_hint is budget_exceeded, or the ledger's budget
                     is exceeded
      TRANSIENT      anything else
    """
    run_dir = run_dir_path(args.run_dir)
    status = load_status(run_dir)
    terminal_status = read_json(Path(args.terminal_status).expanduser())

    def verdict() -> str:
        converged = (
            terminal_status.get("converge_status") == "CONVERGED"
            and int(terminal_status.get("exit_code", 1)) == 0
        )
        if converged:
            return "CONVERGED"
        hint = terminal_status.get("cause_hint")
        signature = terminal_failure_signature(terminal_status)
        seen_before = set(status.get("prior_failure_signatures") or [])
        if hint in DETERMINISTIC_CAUSES or signature in seen_before:
            return "DETERMINISTIC"
        if hint == "budget_exceeded" or budget_exceeded(status):
            return "CAPPED_BUDGET"
        return "TRANSIENT"

    print(verdict())
    return 0


def cmd_budget_check(args: argparse.Namespace) -> int:
    """Update budget counters and report whether the budget is exceeded.

    Under the run lock: optionally overwrites ``wall_clock_s_used`` and/or
    adds to ``codex_rounds_used``, saves the ledger, and logs a
    ``budget_check`` event. Returns 1 when the budget is exceeded, else 0.
    """
    run_dir = run_dir_path(args.run_dir)

    with RunLock(run_dir):
        status = load_status(run_dir)
        # Work on a copy; a missing or null budget starts out empty.
        budget = dict(status.get("budget") or {})
        if args.set_wall_used is not None:
            budget["wall_clock_s_used"] = int(args.set_wall_used)
        if args.add_rounds is not None:
            rounds_so_far = int(budget.get("codex_rounds_used", 0))
            budget["codex_rounds_used"] = rounds_so_far + int(args.add_rounds)
        status["budget"] = budget
        status["updated_at"] = iso_now()
        over_budget = budget_exceeded(status)
        save_status(
            run_dir,
            status,
            "budget_check",
            wall_clock_s_used=budget.get("wall_clock_s_used"),
            codex_rounds_used=budget.get("codex_rounds_used"),
            exceeded=over_budget,
        )
    if over_budget:
        return 1
    return 0


def cmd_concurrency_check(args: argparse.Namespace) -> int:
    """Compare the number of non-terminal runs with CONCURRENT_RUNS_MAX.

    Scans ``<runs_root>/*/status.json`` (default root
    ``~/.recoil/dispatch-runs``) and counts every run whose state is not in
    TERMINAL_STATES; unreadable or unparsable status files are skipped.
    Prints 'CAP <n>/<max>' and returns 1 when at or over the cap, otherwise
    prints 'OK <n>/<max>' and returns 0. The /dispatch skill calls this
    instead of hardcoding the number, so the cap lives only in this module.
    """
    if args.runs_root:
        runs_root = Path(args.runs_root).expanduser()
    else:
        runs_root = Path.home() / ".recoil/dispatch-runs"

    active = 0
    for status_file in sorted(runs_root.glob("*/status.json")):
        try:
            state = json.loads(status_file.read_text()).get("state", "")
        except Exception:
            # Best effort: a broken ledger does not count toward the cap.
            continue
        if state not in TERMINAL_STATES:
            active += 1

    at_cap = active >= CONCURRENT_RUNS_MAX
    label = "CAP" if at_cap else "OK"
    print(f"{label} {active}/{CONCURRENT_RUNS_MAX}")
    return 1 if at_cap else 0


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: one required subcommand per ``cmd_*`` handler.

    Each subparser stores its handler under ``func``; ``main`` dispatches on
    it.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    commands = parser.add_subparsers(dest="command", required=True)

    def register(name: str, handler: Any, *, run_dir: bool = True) -> argparse.ArgumentParser:
        # Create a subcommand, optionally with the shared --run-dir flag.
        sub = commands.add_parser(name)
        if run_dir:
            sub.add_argument("--run-dir", required=True)
        sub.set_defaults(func=handler)
        return sub

    concurrency_cmd = register("concurrency-check", cmd_concurrency_check, run_dir=False)
    concurrency_cmd.add_argument("--runs-root", default=None)

    init_cmd = register("init", cmd_init)
    for flag in ("--issue", "--branch", "--worktree", "--spec", "--last-validated-commit"):
        init_cmd.add_argument(flag, required=True)
    init_cmd.add_argument(
        "--phases",
        type=int,
        default=None,
        help="Phase count from the BUILD_SPEC; sets a dynamic codex_rounds_max "
        "(2*N+10). Omit to keep the flat 18 default.",
    )

    transition_cmd = register("transition", cmd_transition)
    transition_cmd.add_argument("--state", required=True)
    for flag in (
        "--pr-url",
        "--commit",
        "--failure-signature",
        "--failure-reason",
        "--last-gate",
        "--last-validation-command",
    ):
        transition_cmd.add_argument(flag)

    set_pr_url_cmd = register("set-pr-url", cmd_set_pr_url)
    set_pr_url_cmd.add_argument("--pr-url", required=True)

    projection_cmd = register("projection", cmd_projection)
    # Exactly one of --dirty / --clean must be given.
    mode = projection_cmd.add_mutually_exclusive_group(required=True)
    mode.add_argument("--dirty", action="store_true")
    mode.add_argument("--clean", action="store_true")
    projection_cmd.add_argument("--expected-state")

    retry_cmd = register("retry-start", cmd_retry_start)
    retry_cmd.add_argument("--failure-signature")
    retry_cmd.add_argument("--zombie", action="store_true")

    signature_cmd = register("signature", cmd_signature, run_dir=False)
    signature_cmd.add_argument("--terminal-status", required=True)

    classify_cmd = register("classify", cmd_classify)
    classify_cmd.add_argument("--terminal-status", required=True)

    budget_cmd = register("budget-check", cmd_budget_check)
    budget_cmd.add_argument("--set-wall-used", type=int)
    budget_cmd.add_argument("--add-rounds", type=int)

    return parser


def main(argv: list[str] | None = None) -> int:
    """Parse *argv*, run the selected subcommand and return its exit status.

    Returns 2 when ``projection --clean`` is used without
    ``--expected-state``, and 1 (message on stderr) when the handler raises
    DispatchStatusError.
    """
    args = build_parser().parse_args(argv)
    # argparse cannot express "--clean needs --expected-state", so check here.
    wants_clean = args.command == "projection" and args.clean
    if wants_clean and not args.expected_state:
        print("projection --clean requires --expected-state", file=sys.stderr)
        return 2
    try:
        exit_code = int(args.func(args))
    except DispatchStatusError as exc:
        print(str(exc), file=sys.stderr)
        return 1
    return exit_code


if __name__ == "__main__":
    raise SystemExit(main())
