#!/usr/bin/env python3
"""orchestrate_guard.py — the deterministic heart of /orchestrate.

Owns ONLY reproducible decisions: round-accounting, the stop-rule over structured
findings, dispatch/merge/spend gate-checks (for the hooks), and the event-sourced
state SSOT with atomic, hash-chained, flock-protected writes + rewind.

It NEVER reads reviewer prose. Semantic classification (a finding's `kind`) is made
by the gpt-5.5 reviewer and captured in findings.json BEFORE the guard sees it; the
guard does arithmetic over that structure. Given the same ORCH_LOG.jsonl + the same
round findings, the guard returns the same decision every time.

Design SSOT: consultations/tooling/orchestrate-engine-2026-06-18/SYNTHESIS.md
Hardened per guard_review_codex.md (2026-06-18): iterative rewind replay, strict
findings schema + normalized surfaces, parsed-UTC spend gate, recovery-aware
record-round, explicit stop-code priority, lock-free reads.

Exit codes:
  0  = READY / advance allowed
  10 = CONTINUE (autonomous fix allowed)
  20 = STOP (human gate required) — reason in stdout JSON
  30 = ERROR / corrupt-or-ambiguous state -> safe-default STOP
  2  = gate-check BLOCK (PreToolUse hooks use this)
"""
from __future__ import annotations

import argparse
import contextlib
import datetime as _dt
import fcntl
import hashlib
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional

# Stored as "schema_version" in every state built by _initial_state().
SCHEMA_VERSION = 1
# Round caps applied by the stop-rule (_evaluate): a round numbered above the cap
# for its loop is a STOP_ROUNDS.
SPEC_MAX_ROUNDS = 5
ULTRA_MAX_ROUNDS = 1

# Process exit codes (see the module docstring for the caller-facing contract).
EXIT_READY = 0        # READY / advance allowed / gate-check allow
EXIT_CONTINUE = 10    # CONTINUE: autonomous fix allowed
EXIT_STOP_HUMAN = 20  # STOP: human gate required
EXIT_ERROR = 30       # guard error -> safe-default STOP
EXIT_GATE_BLOCK = 2   # gate-check BLOCK

# Severities (compared upper-cased) that the stop-rule counts as critical.
_CRIT = {"CRITICAL", "HIGH"}
# Hard human-stop kinds = GENUINE product decisions only: a build-BOUNDARY call (scope:
# does this work belong in THIS build) or something needing info/authority/spend/policy/
# threshold (needs-human), or an unclassifiable finding (ambiguous, fail-closed to human).
# `design` (an architecture FORK with a grounded default) is NOT here (JT 2026-06-23,
# dogfood): the conductor resolves it by grounding in the live code + recording a REVERSIBLE
# DECISION fork, then re-gates. Genuine non-convergence of design findings is still caught by
# the whack-a-mole (same surface re-fails), severity-stall (criticals not strictly decreasing),
# and round-cap guards — which rule (1) used to preempt, making detailed specs un-convergeable.
_HUMAN_KINDS = {"scope", "needs-human", "needs_human", "ambiguous"}
# prev_hash expected on the first log event, and the chain head of an empty log.
_ZERO_HASH = "sha256:" + "0" * 64


class GuardError(Exception):
    """Any guard-domain failure -> caller fail-closes to STOP/BLOCK.

    Raised for a corrupt log line, a broken hash chain, an unreadable or
    malformed findings file, and a critical finding without a file/surface.
    The mutating commands and ``status`` catch it and return EXIT_ERROR;
    ``gate-check`` catches every exception and returns EXIT_GATE_BLOCK.
    """


# ----------------------------------------------------------------------------- io
def _now() -> _dt.datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    utc = _dt.timezone.utc
    return _dt.datetime.now(tz=utc)


def _now_iso() -> str:
    """Return the current UTC time as a second-resolution ``YYYY-MM-DDTHH:MM:SSZ`` string."""
    stamp_format = "%Y-%m-%dT%H:%M:%SZ"
    current = _now()
    return current.strftime(stamp_format)


def _parse_iso(s: Any) -> Optional[_dt.datetime]:
    """Strictly parse a ``YYYY-MM-DDTHH:MM:SSZ`` UTC timestamp.

    Returns a timezone-aware UTC datetime, or None for anything that is not a
    real timestamp in exactly that shape (so a money/expiry gate can never be
    fooled by a lexicographically-large string). Non-strings, wrong shapes and
    impossible calendar values (e.g. Feb 30) all return None.
    """
    if not isinstance(s, str):
        return None
    # re.ASCII: without it `\d` matches every Unicode decimal digit (full-width,
    # Arabic-Indic, ...), which strptime/int() would then happily convert.
    if not re.fullmatch(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", s, flags=re.ASCII):
        return None
    try:
        return _dt.datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=_dt.timezone.utc)
    except ValueError:
        # Shape was right but the value is not a real date/time.
        return None


def _canonical(obj: Any) -> str:
    """Serialize *obj* to its canonical JSON text: sorted keys, no padding, raw Unicode."""
    compact = (",", ":")
    return json.dumps(obj, ensure_ascii=False, separators=compact, sort_keys=True)


def _sha256(s: str) -> str:
    """Return ``sha256:<hex digest>`` of the UTF-8 encoding of *s*."""
    hasher = hashlib.sha256()
    hasher.update(s.encode("utf-8"))
    return "sha256:" + hasher.hexdigest()


@contextlib.contextmanager
def _locked(run_dir: Path):
    """Hold an exclusive flock on ``<run_dir>/lock`` for the duration of the block.

    Every MUTATING operation runs inside this context. The lock file is created
    on demand; the lock is released and the handle closed on exit, including
    when the body raises.
    """
    lock_file = run_dir / "lock"
    lock_file.touch(exist_ok=True)
    with open(lock_file, "r+") as handle:
        try:
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
            yield
        finally:
            fcntl.flock(handle.fileno(), fcntl.LOCK_UN)


def _log_path(run_dir: Path) -> Path:
    """Return the location of the event log inside *run_dir*."""
    return run_dir.joinpath("ORCH_LOG.jsonl")


def _state_path(run_dir: Path) -> Path:
    """Return the location of the materialized state snapshot inside *run_dir*."""
    return run_dir.joinpath("ORCH_STATE.json")


def _read_log(run_dir: Path) -> list[dict]:
    """Read every event from the run's JSONL log, in file order.

    Returns [] when the log file does not exist. Blank lines are skipped.
    Does NOT verify the hash chain; callers pair this with _verify_chain.

    Raises:
        GuardError: the file cannot be read or decoded as UTF-8, a line is not
            valid JSON, or a line is valid JSON but not an object.
    """
    p = _log_path(run_dir)
    if not p.exists():
        return []
    try:
        text = p.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        # Fail closed through the guard's own error type instead of a traceback.
        raise GuardError(f"log unreadable: {e}") from e
    events: list[dict] = []
    # Split on "\n" ONLY. str.splitlines() also breaks on U+0085/U+2028/U+2029,
    # which json.dumps(ensure_ascii=False) writes unescaped inside strings, so a
    # payload containing one would be cut mid-event and brick the whole log.
    for ln, line in enumerate(text.split("\n"), start=1):
        line = line.strip()
        if not line:
            continue
        try:
            ev = json.loads(line)
        except json.JSONDecodeError as e:
            raise GuardError(f"corrupt log at line {ln}: {e}") from e
        if not isinstance(ev, dict):
            # Downstream code calls ev.get()/ev.items(); reject early and cleanly.
            raise GuardError(f"corrupt log at line {ln}: event is not a JSON object")
        events.append(ev)
    return events


def _verify_chain(events: list[dict]) -> None:
    """Check that *events* form an unbroken hash chain starting at the zero hash.

    Each event's ``prev_hash`` must equal the previous event's ``event_hash``
    and its ``event_hash`` must equal the hash of the canonical form of every
    other field.

    Raises:
        GuardError: on the first link or hash that does not match.
    """
    expected_prev = _ZERO_HASH
    for index, event in enumerate(events):
        if event.get("prev_hash") != expected_prev:
            raise GuardError(f"hash-chain break at seq {event.get('seq')} (index {index})")
        # Hash covers everything except the hash field itself.
        unhashed = dict(event)
        unhashed.pop("event_hash", None)
        recomputed = _sha256(_canonical(unhashed))
        if event.get("event_hash") != recomputed:
            raise GuardError(f"event_hash mismatch at seq {event.get('seq')}")
        expected_prev = event["event_hash"]


def _append_event(run_dir: Path, event: dict, events: Optional[list[dict]] = None) -> tuple[dict, list[dict]]:
    """Seal *event* onto the end of the hash chain and append it to the log.

    Must be called while holding _locked. When *events* is omitted the log is
    read and chain-verified here; passing the caller's already-read+verified
    list avoids a re-read race.

    Returns:
        (sealed_event, events + [sealed_event]). The input list is not mutated.
    """
    if events is None:
        events = _read_log(run_dir)
        _verify_chain(events)
    if events:
        tail = events[-1]
        seq, prev_hash = tail["seq"] + 1, tail["event_hash"]
    else:
        seq, prev_hash = 0, _ZERO_HASH
    # Same key insertion order as copy-then-assign, so the written line is unchanged.
    sealed = {**event, "seq": seq, "ts": _now_iso(), "prev_hash": prev_hash}
    sealed["event_hash"] = _sha256(_canonical(sealed))
    with open(_log_path(run_dir), "a", encoding="utf-8") as sink:  # linearization point
        sink.write(json.dumps(sealed, ensure_ascii=False) + "\n")
        sink.flush()
        os.fsync(sink.fileno())
    return sealed, [*events, sealed]


# --------------------------------------------------------------------- materialize
def _initial_state() -> dict:
    """Return a fresh state dict as it stands before any event is applied."""
    return {
        "schema_version": SCHEMA_VERSION,
        # Run identity (filled by INIT).
        "run_id": None,
        "target": None,
        "repo": None,
        "issue": None,
        "worktree": None,
        # Workflow position and round accounting.
        "state": "TARGET_INTAKE",
        "spec_round": 0,
        "ultra_round": 0,
        "fixed_surfaces": {},
        "prev_spec_crit": None,
        "spec_selfgate": "PENDING",
        # Gates.
        "pending_gate": None,
        "last_stop_code": None,
        "verify_passed_by_human": False,
        # Delivery bookkeeping.
        "pr_url": None,
        "base_commit": None,
        "last_validated_commit": None,
        "spend_authorized_until": None,
        # Chain head (overwritten by _materialize when the log is non-empty).
        "state_seq": -1,
        "prev_hash": _ZERO_HASH,
        "updated_at": None,
    }


def _apply(state: dict, ev: dict) -> dict:
    """Fold one NON-rewind event into *state* (mutated in place and returned).

    REWIND events are resolved upstream by _effective_events, so they never
    reach this reducer. Unknown event types leave the state untouched.
    """
    kind = ev.get("type")
    payload = ev.get("payload", {})

    def copy_present(keys):
        # Copy only keys that appear in the payload; an explicit None IS copied.
        for key in keys:
            if key in payload:
                state[key] = payload[key]

    if kind == "INIT":
        copy_present(("run_id", "target", "repo", "issue", "worktree", "base_commit"))
        state["state"] = "AUTHOR_SPEC"
        return state

    if kind == "TRANSITION":
        state["state"] = payload["to_state"]
        copy_present(("pending_gate", "spec_selfgate", "pr_url", "verify_passed_by_human",
                      "spend_authorized_until"))
        if payload.get("validated_commit"):
            state["last_validated_commit"] = payload["validated_commit"]
        return state

    if kind == "ROUND_RECORDED":
        if payload["loop"] == "spec":
            state["spec_round"] = payload["round"]
            state["prev_spec_crit"] = payload["crit_count"]
            for surface in payload.get("new_fixed_surfaces", []):
                state["fixed_surfaces"][surface] = state["fixed_surfaces"].get(surface, 0) + 1
            state["spec_selfgate"] = "READY" if payload["verdict_ready"] else "NEEDS-FIXES"
        else:
            state["ultra_round"] = payload["round"]
        # A round can only SET these; None means leave-unchanged.
        gate = payload.get("pending_gate")
        if gate is not None:
            state["pending_gate"] = gate
        code = payload.get("stop_code")
        if code is not None:
            state["last_stop_code"] = code
        return state

    if kind == "DISPATCH_SYNC":
        if payload.get("pr_url"):
            state["pr_url"] = payload["pr_url"]
        if payload.get("to_state"):
            state["state"] = payload["to_state"]
        gate = payload.get("pending_gate")
        if gate is not None:
            state["pending_gate"] = gate
        return state

    return state


def _effective_events(events: list[dict]) -> list[dict]:
    """Collapse REWINDs into the effective (live) stream of non-rewind events.

    Walks the log once, in order. A REWIND with ``to_seq = T`` discards every
    event collected so far whose seq is greater than T; REWIND events
    themselves are never part of the result. Iterative, so the outcome is
    total and deterministic for any accepted log (rewind targets are validated
    to be non-REWIND at write time).
    """
    survivors: list[dict] = []
    for event in events:
        if event.get("type") != "REWIND":
            survivors.append(event)
            continue
        cutoff = event["payload"]["to_seq"]
        survivors = [kept for kept in survivors if kept["seq"] <= cutoff]
    return survivors


def _materialize(events: list[dict]) -> dict:
    """Rebuild the full state by replaying the live events over the initial state."""
    rebuilt = _initial_state()
    for live_event in _effective_events(events):
        rebuilt = _apply(rebuilt, live_event)
    if not events:
        return rebuilt
    # Chain head comes from the LAST real log event (rewinds included) so the
    # next append continues the chain correctly.
    head = events[-1]
    rebuilt["state_seq"] = head["seq"]
    rebuilt["prev_hash"] = head["event_hash"]
    rebuilt["updated_at"] = head["ts"]
    return rebuilt


def _write_snapshot(run_dir: Path, state: dict) -> None:
    """Write *state* to the snapshot file via a temp file and an atomic rename."""
    scratch = _state_path(run_dir).with_suffix(".json.tmp")
    rendered = json.dumps(state, indent=2, ensure_ascii=False)
    scratch.write_text(rendered, encoding="utf-8")
    os.replace(scratch, _state_path(run_dir))


def _materialize_from(run_dir: Path, events: Optional[list[dict]] = None) -> dict:
    """Materialize state from *events*, or from the on-disk log when none are given.

    Events supplied by the caller are trusted as already verified; events read
    here are chain-verified before use.
    """
    if events is not None:
        return _materialize(events)
    from_disk = _read_log(run_dir)
    _verify_chain(from_disk)
    return _materialize(from_disk)


def _load_state(run_dir: Path, *, write_snapshot: bool = True) -> dict:
    """Rebuild the state from the log (the SSOT; the log always wins).

    Read-only callers pass ``write_snapshot=False`` so they never touch the
    snapshot temp file while holding no lock.
    """
    rebuilt = _materialize_from(run_dir)
    if not write_snapshot:
        return rebuilt
    _write_snapshot(run_dir, rebuilt)
    return rebuilt


# ------------------------------------------------------------------------ stop-rule
def _norm_surface(f: dict) -> str:
    """Return the normalized ``file:surface`` key for a finding.

    The file part gets backslashes turned into ``/`` and one leading ``./``
    removed; the surface part gets whitespace runs collapsed to single spaces
    and is case-folded.

    Raises:
        GuardError: if file or surface is missing, null, or blank (a crit/high
            finding MUST name a real surface — fail closed, never collapse to
            a placeholder key).
    """
    raw_file = f.get("file")
    raw_surf = f.get("surface")
    # JSON null must count as missing: str(None) is the non-empty text "None",
    # which would otherwise slip past the emptiness check below.
    file = "" if raw_file is None else str(raw_file).strip()
    surf = "" if raw_surf is None else str(raw_surf).strip()
    if not file or not surf:
        raise GuardError(f"critical/high finding missing file/surface: {f!r}")
    file = re.sub(r"^\./", "", file.replace("\\", "/"))
    surf = re.sub(r"\s+", " ", surf).casefold()
    return f"{file}:{surf}"


def _load_findings(path: Path) -> dict:
    """Load and shape-check a findings document.

    The document must be a JSON object whose ``findings`` value is a list of
    objects. No per-field validation happens here.

    Raises:
        GuardError: the file cannot be read, is not valid UTF-8, is not valid
            JSON, or does not have the required shape.
    """
    try:
        data = json.loads(Path(path).read_text(encoding="utf-8"))
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError AND UnicodeDecodeError; the
        # latter used to escape as a traceback instead of a guard error.
        raise GuardError(f"findings.json unreadable: {e}") from e
    if not isinstance(data, dict) or not isinstance(data.get("findings"), list):
        raise GuardError("findings.json must be an object with a 'findings' list")
    for f in data["findings"]:
        if not isinstance(f, dict):
            raise GuardError("each finding must be an object")
    return data


def _evaluate(state: dict, findings_doc: dict, verdict_ready: bool, loop: str) -> dict:
    """The mechanical stop-rule over one round's structured findings.

    Explicit priority (first match wins):
    human-kind(scope/needs-human/ambiguous) > drift-guard > clean READY >
    whack-a-mole > round-cap > stall > CONTINUE.
    `design` findings are NOT a hard stop — they ride the whack-a-mole/round-cap/stall path so the
    conductor can auto-converge them (with a recorded reversible decision).

    Args:
        state: materialized state; reads spec_round, ultra_round, fixed_surfaces, prev_spec_crit.
        findings_doc: object with a ``findings`` list of finding objects.
        verdict_ready: whether the reviewer's prose verdict was READY.
        loop: ``"spec"``; any other value is treated as the ultra loop.

    Returns:
        Decision dict with keys decision (READY/CONTINUE/STOP), exit, round, crit_count,
        new_fixed_surfaces, stop_code, reason, plus rule-specific extras on some STOPs.

    Raises:
        GuardError: a CRITICAL/HIGH finding lacks a file/surface (fail-closed upstream).
    """
    def label(finding: dict, field: str) -> str:
        # Trim before comparing: an untrimmed "HIGH " would silently not count as
        # critical (fail-open), and " scope" would skip the human gate.
        return str(finding.get(field, "")).strip()

    findings = findings_doc.get("findings", [])
    crit = [f for f in findings if label(f, "severity").upper() in _CRIT]
    surfaces = [_norm_surface(f) for f in crit]  # raises on missing file/surface
    round_no = (state["spec_round"] if loop == "spec" else state["ultra_round"]) + 1

    def stop(code, reason, **extra):
        return {"decision": "STOP", "exit": EXIT_STOP_HUMAN, "round": round_no,
                "crit_count": len(crit), "new_fixed_surfaces": surfaces,
                "stop_code": code, "reason": reason, **extra}

    # (1) Genuine-product-decision finding -> human gate (highest priority). Only scope /
    # needs-human / ambiguous (NOT design — the conductor auto-resolves those, see (4)-(6)).
    human = [f for f in findings if label(f, "kind").lower() in _HUMAN_KINDS]
    if human:
        return stop("STOP_SCOPE", "human-domain finding", human_findings=human)
    # (2) Dual-guard drift: prose READY but JSON has crit/high (preserves REC-178 catch).
    if verdict_ready and crit:
        return stop("DRIFT_GUARD", "prose READY but findings.json has CRITICAL/HIGH")
    # (3) Clean READY: prose READY AND zero crit/high.
    if verdict_ready and not crit:
        return {"decision": "READY", "exit": EXIT_READY, "round": round_no,
                "crit_count": 0, "new_fixed_surfaces": [], "stop_code": None,
                "reason": "READY"}
    # (4) Whack-a-mole: a CRITICAL on a surface already fixed once before.
    for s in surfaces:
        if state["fixed_surfaces"].get(s, 0) >= 1:
            return stop("STOP_WHACKAMOLE", f"new critical on already-fixed surface {s}",
                        surface=s, recommend="SIMPLIFY_SPEC_FAIL_LOUD")
    cap = SPEC_MAX_ROUNDS if loop == "spec" else ULTRA_MAX_ROUNDS
    # (5) Round cap (checked before stall so a hard cap breach reports STOP_ROUNDS).
    if round_no > cap:
        return stop("STOP_ROUNDS", f"{loop} round cap {cap} exceeded")
    # (6) Severity stall: spec criticals not strictly decreasing from round >= 2.
    prev = state["prev_spec_crit"]
    if loop == "spec" and round_no >= 2 and prev is not None and len(crit) >= prev:
        return stop("STOP_STALL", f"criticals not decreasing ({prev} -> {len(crit)})")
    # (7) Progress -> autonomous fix allowed.
    return {"decision": "CONTINUE", "exit": EXIT_CONTINUE, "round": round_no,
            "crit_count": len(crit), "new_fixed_surfaces": surfaces, "stop_code": None,
            "reason": "progress: apply fixes and re-gate"}


# Map a stop-code to the human gate the skill should surface.
# record-round looks the round's stop_code up here (only for STOP decisions) and
# stores the result as the event's pending_gate; a code missing from this table
# yields None, i.e. no gate is set. Every code currently maps to the same gate.
_STOP_GATE = {
    "STOP_SCOPE": "AWAIT_HUMAN_SCOPE", "STOP_WHACKAMOLE": "AWAIT_HUMAN_SCOPE",
    "STOP_STALL": "AWAIT_HUMAN_SCOPE", "STOP_ROUNDS": "AWAIT_HUMAN_SCOPE",
    "DRIFT_GUARD": "AWAIT_HUMAN_SCOPE",
}


# ------------------------------------------------------------------------ commands
def cmd_init(args) -> int:
    """`init`: create the run directory and write the INIT event plus first snapshot.

    Returns EXIT_READY on success; EXIT_ERROR when the run already has events
    or a guard error occurs. The outcome is printed to stdout as JSON.
    """
    root = Path(args.run_dir)
    root.mkdir(parents=True, exist_ok=True)
    root.joinpath("artifacts").mkdir(exist_ok=True)
    try:
        with _locked(root):
            existing = _read_log(root)
            if existing:
                print(json.dumps({"error": "run already initialized"}))
                return EXIT_ERROR
            fields = ("run_id", "target", "repo", "issue", "worktree", "base_commit")
            init_event = {
                "type": "INIT",
                "actor": "orchestrate",
                "payload": {name: getattr(args, name) for name in fields},
            }
            _, log = _append_event(root, init_event)
            state = _materialize(log)
            _write_snapshot(root, state)
    except GuardError as err:
        print(json.dumps({"error": str(err)}))
        return EXIT_ERROR
    print(json.dumps({"ok": True, "state": state["state"], "run_dir": str(root)}))
    return EXIT_READY


def cmd_record_round(args) -> int:
    """`record-round`: evaluate one round's findings and record the decision.

    Under the run lock: rebuild state from the verified log, run the stop-rule,
    append a ROUND_RECORDED event, then refresh the snapshot (best-effort).
    Prints the decision as JSON and returns its exit code; any guard error
    prints a GUARD_ERROR stop and returns EXIT_ERROR.
    """
    root = Path(args.run_dir)
    try:
        with _locked(root):
            log = _read_log(root)
            _verify_chain(log)
            current = _materialize(log)
            doc = _load_findings(Path(args.findings))
            decision = _evaluate(current, doc, bool(args.verdict_ready), args.loop)
            stopped = decision["decision"] == "STOP"
            gate = _STOP_GATE.get(decision.get("stop_code")) if stopped else None
            round_event = {
                "type": "ROUND_RECORDED",
                "actor": "guard",
                "payload": {
                    "loop": args.loop,
                    "round": decision["round"],
                    "crit_count": decision["crit_count"],
                    "new_fixed_surfaces": decision["new_fixed_surfaces"],
                    "verdict_ready": decision["decision"] == "READY",
                    "pending_gate": gate,
                    "stop_code": decision.get("stop_code"),
                    "decision": decision["decision"],
                    "findings_ref": str(args.findings),
                },
            }
            _, extended = _append_event(root, round_event, events=log)
            # Recovery-aware: the round IS recorded at this point; a failed
            # snapshot refresh only adds a warning to the printed decision.
            try:
                _write_snapshot(root, _materialize(extended))
            except OSError as err:
                decision["snapshot_warning"] = f"round recorded; snapshot refresh failed: {err}"
    except GuardError as err:
        failure = {"decision": "STOP", "reason": f"guard error: {err}",
                   "stop_code": "GUARD_ERROR", "exit": EXIT_ERROR}
        print(json.dumps(failure))
        return EXIT_ERROR
    print(json.dumps(decision))
    return decision["exit"]


def cmd_transition(args) -> int:
    """`transition`: append a TRANSITION event built from the CLI flags.

    Only flags that were actually given end up in the payload. Returns
    EXIT_READY on success; EXIT_ERROR on a malformed spend timestamp (checked
    before anything is written) or a guard error.
    """
    root = Path(args.run_dir)
    payload: dict[str, Any] = {"to_state": args.to_state}
    optional_fields = ("git_head", "validated_commit", "pending_gate", "spec_selfgate", "pr_url")
    for field in optional_fields:
        value = getattr(args, field, None)
        if value is not None:
            payload[field] = value
    # Explicit clear: the ONLY way to express "human gate resolved -> no pending gate"
    # (omission means leave-unchanged). Clear wins if both are passed.
    if args.clear_pending_gate:
        payload["pending_gate"] = None
    spend_until = args.spend_authorized_until
    if spend_until is not None:
        if _parse_iso(spend_until) is None:
            print(json.dumps({"error": "spend-authorized-until must be UTC ISO YYYY-MM-DDTHH:MM:SSZ"}))
            return EXIT_ERROR
        payload["spend_authorized_until"] = spend_until
    if args.verify_passed:
        payload["verify_passed_by_human"] = True
    transition_event = {"type": "TRANSITION", "actor": args.actor, "payload": payload}
    try:
        with _locked(root):
            log = _read_log(root)
            _verify_chain(log)
            _, extended = _append_event(root, transition_event, events=log)
            state = _materialize(extended)
            _write_snapshot(root, state)
    except GuardError as err:
        print(json.dumps({"error": str(err)}))
        return EXIT_ERROR
    print(json.dumps({"ok": True, "state": state["state"]}))
    return EXIT_READY


def cmd_rewind(args) -> int:
    """`rewind`: append a REWIND event that makes ``--to-seq`` the latest live event.

    The target must exist in the log, must not itself be a REWIND, and must
    still be live. Returns EXIT_READY on success, EXIT_ERROR otherwise.
    """
    root = Path(args.run_dir)
    wanted = args.to_seq
    try:
        with _locked(root):
            log = _read_log(root)
            _verify_chain(log)
            target = None
            for candidate in log:
                if candidate["seq"] == wanted:
                    target = candidate
                    break
            if target is None:
                print(json.dumps({"error": f"seq {args.to_seq} not in log"}))
                return EXIT_ERROR
            if target.get("type") == "REWIND":
                print(json.dumps({"error": "cannot rewind to a REWIND event"}))
                return EXIT_ERROR
            # Target must be LIVE in the effective stream — rewinding to an event a
            # prior rewind already dropped would be a silent no-op (misleading).
            still_live = any(live["seq"] == wanted for live in _effective_events(log))
            if not still_live:
                print(json.dumps({"error": f"seq {args.to_seq} is not live (already superseded by a prior rewind)"}))
                return EXIT_ERROR
            rewind_event = {
                "type": "REWIND",
                "actor": "orchestrate",
                "payload": {"to_seq": args.to_seq, "reason": args.reason},
            }
            _, extended = _append_event(root, rewind_event, events=log)
            state = _materialize(extended)
            _write_snapshot(root, state)
    except GuardError as err:
        print(json.dumps({"error": str(err)}))
        return EXIT_ERROR
    print(json.dumps({"ok": True, "rewound_to": args.to_seq, "state": state["state"]}))
    return EXIT_READY


def cmd_status(args) -> int:
    """`status`: print the state rebuilt from the log, without writing anything."""
    root = Path(args.run_dir)
    try:
        current = _load_state(root, write_snapshot=False)
    except GuardError as err:
        print(json.dumps({"error": f"corrupt state: {err}"}))
        return EXIT_ERROR
    print(json.dumps(current, indent=2))
    return EXIT_READY


def cmd_gate_check(args) -> int:
    """PreToolUse hooks. 0=allow, 2=BLOCK (reason on stderr). Fail-closed.

    dispatch: needs spec_selfgate READY and no pending gate other than AWAIT_SPEND.
    merge:    needs pending gate AWAIT_MERGE and a recorded human verify.
    spend:    needs a parseable spend_authorized_until that is not in the past.
    Unreadable state or an unknown gate name blocks.
    """
    def deny(message: str) -> int:
        sys.stderr.write(message)
        return EXIT_GATE_BLOCK

    try:
        state = _load_state(Path(args.run_dir), write_snapshot=False)
    except Exception as err:  # fail-closed on anything
        return deny(f"orchestrate gate-check: unreadable state -> BLOCK ({err})\n")

    gate = args.gate
    if gate == "dispatch":
        selfgate = state.get("spec_selfgate")
        if selfgate != "READY":
            return deny(f"BLOCK dispatch: spec_selfgate={selfgate} (Step 9.9 not satisfied)\n")
        open_gate = state.get("pending_gate")
        # Anything but "none" or the spend gate means a human scope/verify/merge gate is open.
        if open_gate is not None and open_gate != "AWAIT_SPEND":
            return deny(f"BLOCK dispatch: pending human gate {open_gate}\n")
        return EXIT_READY

    if gate == "merge":
        if state.get("pending_gate") != "AWAIT_MERGE" or not state.get("verify_passed_by_human"):
            return deny("BLOCK merge: merge gate unmet — live-verify not recorded "
                        "(need AWAIT_MERGE + verify_passed). Converge + verify before merging.\n")
        return EXIT_READY

    if gate == "spend":
        deadline = _parse_iso(state.get("spend_authorized_until"))
        if deadline is None or deadline < _now():
            return deny("BLOCK spend: no valid standing authorization\n")
        return EXIT_READY

    return deny(f"BLOCK: unknown gate {args.gate}\n")


def main(argv: Optional[list[str]] = None) -> int:
    """Parse the command line, run the selected sub-command, return its exit code."""
    parser = argparse.ArgumentParser(description="orchestrate guard — deterministic state + stop-rule")
    commands = parser.add_subparsers(dest="cmd", required=True)

    def command(name, handler):
        # Every sub-command takes --run-dir first and binds its handler as `func`.
        sub = commands.add_parser(name)
        sub.add_argument("--run-dir", required=True)
        sub.set_defaults(func=handler)
        return sub

    init = command("init", cmd_init)
    init.add_argument("--run-id", required=True)
    init.add_argument("--target", default="")
    init.add_argument("--repo", required=True)
    init.add_argument("--issue", default="")
    init.add_argument("--worktree", default="")
    init.add_argument("--base-commit", default="")

    record = command("record-round", cmd_record_round)
    record.add_argument("--loop", choices=["spec", "ultra"], required=True)
    record.add_argument("--findings", required=True)
    record.add_argument("--verdict-ready", action="store_true")

    transition = command("transition", cmd_transition)
    transition.add_argument("--to-state", required=True)
    transition.add_argument("--actor", default="orchestrate")
    transition.add_argument("--git-head")
    transition.add_argument("--validated-commit")
    transition.add_argument("--pending-gate")
    transition.add_argument("--clear-pending-gate", action="store_true")
    transition.add_argument("--spec-selfgate")
    transition.add_argument("--pr-url")
    transition.add_argument("--spend-authorized-until")
    transition.add_argument("--verify-passed", action="store_true")

    rewind = command("rewind", cmd_rewind)
    rewind.add_argument("--to-seq", type=int, required=True)
    rewind.add_argument("--reason", default="")

    command("status", cmd_status)

    gate_check = command("gate-check", cmd_gate_check)
    gate_check.add_argument("--gate", choices=["dispatch", "merge", "spend"], required=True)

    parsed = parser.parse_args(argv)
    return parsed.func(parsed)


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
