#!/usr/bin/env python3
"""Global autonomy lease for Studio-local autonomous runs."""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
import socket
import subprocess
import sys
import tempfile
from datetime import datetime, timedelta, timezone
from typing import Any

from recoil.pipeline.tools.autonomy import constants


# Filesystem locations and the lock executable are defined once in the shared
# autonomy constants module and aliased here for brevity.
STATE_DIR = constants.STATE_DIR  # directory created before the lock is taken
LEASE_PATH = constants.LEASE_PATH  # JSON file holding the lease record
LEASE_LOCK = constants.LEASE_LOCK  # lock file passed to the lock command via -f
SHLOCK = constants.SHLOCK  # executable run to take the lock

# Modes accepted by acquire() and convert(), and offered as CLI --mode choices.
VALID_MODES = {"tick", "build", "interactive"}


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)


def _iso(value: datetime) -> str:
    """Format *value* as ISO 8601 with whole-second precision.

    Sub-second digits are dropped (truncated, not rounded) and a ``+00:00``
    offset is written as ``Z``. Any other offset, or the absence of one on a
    naive datetime, is left exactly as ``isoformat`` renders it.
    """
    rendered = value.isoformat(timespec="seconds")
    return rendered.replace("+00:00", "Z")


def _parse_iso(value: object) -> datetime | None:
    if not isinstance(value, str) or not value:
        return None
    try:
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None
    if parsed.tzinfo is None:
        return parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


def _atomic_write_json(path: Path, obj: dict[str, Any]) -> None:
    """Write *obj* to *path* as JSON via a temporary file and ``os.replace``.

    The parent directory is created if needed. The payload is written to a
    hidden sibling temp file, flushed and fsynced, and then moved over *path*,
    so readers see either the previous content or the complete new content.
    The temp file is removed if any step fails.

    Raises whatever serialisation or the filesystem raises (for example
    ``TypeError`` for non-serialisable values, ``OSError`` for I/O failures).
    """
    directory = path.parent
    directory.mkdir(parents=True, exist_ok=True)

    handle_fd, staging_name = tempfile.mkstemp(
        dir=str(directory), prefix=f".{path.name}.", suffix=".tmp"
    )
    staging = Path(staging_name)
    try:
        with os.fdopen(handle_fd, "w", encoding="utf-8") as out:
            out.write(json.dumps(obj, indent=2, sort_keys=True))
            out.write("\n")
            out.flush()
            os.fsync(out.fileno())
        os.replace(staging, path)
    finally:
        # After a successful replace the staging name no longer exists, so
        # this only removes leftovers from a failed write.
        staging.unlink(missing_ok=True)


def _read_json(path: Path) -> dict[str, Any] | None:
    try:
        with path.open("r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError:
        return None
    except (OSError, json.JSONDecodeError):
        raise
    if not isinstance(data, dict):
        raise ValueError(f"{path} must contain a JSON object")
    return data


def read() -> dict[str, Any] | None:
    """Return the current lease record, or ``None`` if there is none to read.

    A missing lease file, an unreadable one, and one with invalid content all
    yield ``None``; this function does not distinguish between them.
    """
    try:
        lease = _read_json(LEASE_PATH)
    except (OSError, ValueError):
        # json.JSONDecodeError is a ValueError subclass, so this also covers
        # malformed JSON.
        lease = None
    return lease


def _acquire_shlock(pid: int) -> bool:
    """Try to take the lease lock file on behalf of *pid*.

    Ensures the state directory exists, then runs the configured lock command
    as ``SHLOCK -p <pid> -f <LEASE_LOCK>`` with its output discarded.

    Returns:
        ``True`` only when the command ran and exited with status 0;
        ``False`` when it exited non-zero or could not be started.
    """
    STATE_DIR.mkdir(parents=True, exist_ok=True)
    command = [SHLOCK, "-p", str(pid), "-f", str(LEASE_LOCK)]
    try:
        completed = subprocess.run(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError:
        # Covers a missing or non-executable lock command.
        return False
    else:
        return completed.returncode == 0


def _unlink_lock() -> None:
    """Delete the lease lock file; a file that is already gone is not an error."""
    LEASE_LOCK.unlink(missing_ok=True)


def _pid_is_dead(pid: object) -> bool:
    """Return ``True`` only when *pid* provably names no running process.

    The check is deliberately conservative because a ``True`` result allows
    an expired lease to be reclaimed. Anything that prevents a definite
    answer yields ``False``:

    * *pid* cannot be converted to an integer (wrong type, non-numeric text,
      NaN or infinity);
    * *pid* is zero or negative;
    * *pid* is too large for the platform's process-id type;
    * the signal probe fails for any reason other than "no such process",
      for example a permission error, which means the process exists.

    Args:
        pid: Value taken from a lease record; may be any JSON-decoded type.

    Returns:
        ``True`` if signal 0 reports that the process does not exist,
        otherwise ``False``.
    """
    try:
        pid_int = int(pid)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")), reachable via JSON "Infinity".
        return False
    if pid_int <= 0:
        return False
    try:
        # Signal 0 performs existence and permission checks without
        # delivering anything to the target.
        os.kill(pid_int, 0)
    except ProcessLookupError:
        return True
    except (OSError, OverflowError):
        # OverflowError: the value does not fit the C pid type. A corrupt
        # record must not crash the caller, so treat it as "not provably dead".
        return False
    return False


def _tmux_session_gone(tmux_session: object) -> bool:
    """Return ``True`` when the recorded tmux session provably no longer exists.

    A record with no tmux session (``None`` or an empty value) has nothing to
    outlive, so it counts as gone. Otherwise the session counts as gone only
    when this user's tmux socket directory is absent from every location
    where tmux could have created it.

    Args:
        tmux_session: Session name taken from a lease record, if any.

    Returns:
        ``True`` if no session was recorded or no ``tmux-<uid>`` socket
        directory exists; ``False`` if any candidate directory exists.
    """
    if not tmux_session:
        return True

    # The Phase 1 boundary forbids shelling out to tmux. The only unambiguous
    # local proof available here is absence of this user's tmux socket root.
    #
    # tmux places that root under $TMUX_TMPDIR, or /tmp when it is unset. It
    # does not follow tempfile.gettempdir(), which honours $TMPDIR and can
    # point somewhere else entirely. Checking only gettempdir() could
    # therefore report "gone" while tmux is running. Every candidate is
    # checked, and one existing directory is enough to refuse the claim.
    leaf = f"tmux-{os.getuid()}"
    bases = {Path(tempfile.gettempdir()), Path("/tmp")}
    override = os.environ.get("TMUX_TMPDIR")
    if override:
        bases.add(Path(override))
    return not any((base / leaf).exists() for base in bases)


def _stale_reclaim_allowed(existing: dict[str, Any], now: datetime) -> bool:
    """Decide whether an existing lease may be taken over as stale.

    All of the following must hold: the record has a parseable ``expires_at``
    strictly earlier than *now*, its ``pid`` is confirmed dead, and its tmux
    session is confirmed gone.
    """
    expiry = _parse_iso(existing.get("expires_at"))
    if expiry is None:
        return False
    if expiry >= now:
        return False
    if not _pid_is_dead(existing.get("pid")):
        return False
    return _tmux_session_gone(existing.get("tmux_session"))


def _holder(mode: str, host: str, pid: int, tmux_session: str | None) -> str:
    parts = [host, str(pid), mode]
    if tmux_session:
        parts.append(tmux_session)
    return ":".join(parts)


def _lease_record(
    mode: str,
    *,
    run_id: str,
    issue_id: str | None,
    ttl: int,
    host: str,
    pid: int,
    tmux_session: str | None,
    worktree_path: str | None,
    now: datetime,
) -> dict[str, Any]:
    """Assemble a fresh schema-version-1 lease record.

    ``acquired_at`` and ``heartbeat_at`` are both set to *now*, and
    ``expires_at`` to *now* plus *ttl* seconds. ``night_id`` comes from
    ``constants.current_night_id(now)``.
    """
    stamp = _iso(now)
    holder = _holder(mode, host, pid, tmux_session)
    expiry = _iso(now + timedelta(seconds=int(ttl)))
    night = constants.current_night_id(now)
    return dict(
        schema_version=1,
        mode=mode,
        holder=holder,
        run_id=run_id,
        issue_id=issue_id,
        host=host,
        pid=pid,
        tmux_session=tmux_session,
        worktree_path=worktree_path,
        acquired_at=stamp,
        heartbeat_at=stamp,
        expires_at=expiry,
        night_id=night,
    )


def acquire(
    mode: str,
    *,
    run_id: str,
    issue_id: str | None = None,
    ttl: int,
    host: str | None = None,
    pid: int | None = None,
    tmux_session: str | None = None,
    worktree_path: str | None = None,
) -> dict[str, Any] | None:
    """Try to take the global lease for *run_id*.

    The lock file is taken first, for *pid* or the current process when *pid*
    is omitted. The existing lease file is then inspected:

    * no lease file: the lease is granted;
    * an unexpired lease with the same ``run_id``: it is overwritten with a
      fresh record;
    * an unexpired lease with a different ``run_id``: refused;
    * an expired or undatable lease: granted only if
      ``_stale_reclaim_allowed`` approves;
    * an unreadable or invalid lease file: refused.

    On success the new record is written and returned, and the lock file is
    left in place. On refusal, or if anything raises after the lock was taken,
    the lock file is removed.

    Returns:
        The new lease record, or ``None`` if the lock or lease was not obtained.

    Raises:
        ValueError: if *mode* is not one of ``VALID_MODES``.
    """
    if mode not in VALID_MODES:
        raise ValueError(f"invalid lease mode: {mode}")

    lock_pid = int(os.getpid() if pid is None else pid)
    if not _acquire_shlock(lock_pid):
        return None

    now = _utc_now()
    try:
        try:
            current = _read_json(LEASE_PATH)
        except (OSError, ValueError):
            # Includes json.JSONDecodeError, a ValueError subclass.
            refused = True
        else:
            refused = False
            if current is not None:
                expiry = _parse_iso(current.get("expires_at"))
                if expiry is not None and expiry > now:
                    # Live lease: only its own run may re-acquire it.
                    refused = current.get("run_id") != run_id
                else:
                    refused = not _stale_reclaim_allowed(current, now)

        if refused:
            _unlink_lock()
            return None

        lease = _lease_record(
            mode,
            run_id=run_id,
            issue_id=issue_id,
            ttl=int(ttl),
            host=host or socket.gethostname(),
            pid=lock_pid,
            tmux_session=tmux_session,
            worktree_path=worktree_path,
            now=now,
        )
        _atomic_write_json(LEASE_PATH, lease)
        return lease
    except Exception:
        # Never leave the lock behind when the lease was not written.
        _unlink_lock()
        raise


def heartbeat(run_id: str, *, ttl: int) -> bool:
    """Refresh the lease owned by *run_id*.

    Sets ``heartbeat_at`` to the current time and ``expires_at`` to *ttl*
    seconds from now, then rewrites the lease file.

    Returns:
        ``True`` if the lease was refreshed; ``False`` if no readable lease
        exists or it belongs to a different run.
    """
    lease = read()
    owned = lease is not None and lease.get("run_id") == run_id
    if not owned:
        return False

    beat = _utc_now()
    lease.update(
        heartbeat_at=_iso(beat),
        expires_at=_iso(beat + timedelta(seconds=int(ttl))),
    )
    _atomic_write_json(LEASE_PATH, lease)
    return True


def release(run_id: str) -> bool:
    """Give up the lease owned by *run_id*.

    Returns:
        ``True`` if there was no readable lease (nothing is touched in that
        case) or if the lease belonged to *run_id* and both the lease file and
        the lock file were removed. ``False`` if the lease belongs to a
        different run.
    """
    lease = read()
    if lease is None:
        return True
    if lease.get("run_id") != run_id:
        return False

    LEASE_PATH.unlink(missing_ok=True)
    _unlink_lock()
    return True


def is_held_fresh() -> bool:
    """Return ``True`` when a readable lease exists and has not yet expired.

    A lease whose ``expires_at`` is missing or unparseable is not fresh.
    """
    lease = read()
    if lease is None:
        return False
    expiry = _parse_iso(lease.get("expires_at"))
    if expiry is None:
        return False
    return expiry > _utc_now()


def convert(
    run_id: str,
    *,
    new_mode: str,
    new_ttl: int,
    pid: int | None = None,
    tmux_session: str | None = None,
) -> bool:
    """Switch the lease owned by *run_id* to *new_mode* and extend it.

    The record's ``pid`` becomes *pid* if given, otherwise the recorded pid,
    otherwise the current process id. ``tmux_session`` is replaced only when
    a value is passed. ``holder`` is rebuilt from the resulting values and the
    recorded host (or the local hostname if none is recorded).
    ``heartbeat_at`` and ``expires_at`` are refreshed using *new_ttl* seconds.

    Returns:
        ``True`` if the lease was rewritten; ``False`` if no readable lease
        exists or it belongs to a different run.

    Raises:
        ValueError: if *new_mode* is not one of ``VALID_MODES``.
    """
    if new_mode not in VALID_MODES:
        raise ValueError(f"invalid lease mode: {new_mode}")

    lease = read()
    if lease is None or lease.get("run_id") != run_id:
        return False

    moment = _utc_now()

    if pid is not None:
        next_pid = int(pid)
    else:
        next_pid = int(lease.get("pid") or os.getpid())

    next_tmux = lease.get("tmux_session") if tmux_session is None else tmux_session
    lease_host = str(lease.get("host") or socket.gethostname())

    lease.update(
        mode=new_mode,
        pid=next_pid,
        tmux_session=next_tmux,
        holder=_holder(new_mode, lease_host, next_pid, next_tmux),
        heartbeat_at=_iso(moment),
        expires_at=_iso(moment + timedelta(seconds=int(new_ttl))),
    )
    _atomic_write_json(LEASE_PATH, lease)
    return True


def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser with its four required subcommands.

    ``heartbeat``, ``release`` and ``convert`` each require ``--run-id``;
    ``read`` takes no options. ``heartbeat --ttl`` defaults to
    ``constants.BUILD_LEASE_TTL``.
    """
    parser = argparse.ArgumentParser(description="Manage the Studio autonomy lease.")
    commands = parser.add_subparsers(dest="command", required=True)

    def add_command(name: str, *, needs_run_id: bool = True) -> argparse.ArgumentParser:
        # Register a subcommand, with the shared --run-id option if required.
        sub = commands.add_parser(name)
        if needs_run_id:
            sub.add_argument("--run-id", required=True)
        return sub

    add_command("heartbeat").add_argument(
        "--ttl", type=int, default=constants.BUILD_LEASE_TTL
    )

    add_command("release")

    convert_cmd = add_command("convert")
    convert_cmd.add_argument("--mode", required=True, choices=sorted(VALID_MODES))
    convert_cmd.add_argument("--ttl", type=int, required=True)
    convert_cmd.add_argument("--pid", type=int)
    convert_cmd.add_argument("--tmux")

    add_command("read", needs_run_id=False)
    return parser


def main(argv: list[str] | None = None) -> int:
    """Run the lease CLI and return a process exit status.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use the process
            arguments.

    Returns:
        ``0`` when the requested operation succeeded (``read`` always does,
        printing the record or ``null`` as JSON), ``1`` when ``heartbeat``,
        ``release`` or ``convert`` reported failure, and ``2`` for a command
        not handled here.
    """
    args = _build_parser().parse_args(argv)
    command = args.command

    if command == "read":
        print(json.dumps(read(), indent=2, sort_keys=True))
        return 0

    if command == "heartbeat":
        succeeded = heartbeat(args.run_id, ttl=args.ttl)
    elif command == "release":
        succeeded = release(args.run_id)
    elif command == "convert":
        succeeded = convert(
            args.run_id,
            new_mode=args.mode,
            new_ttl=args.ttl,
            pid=args.pid,
            tmux_session=args.tmux,
        )
    else:
        return 2

    return 0 if succeeded else 1


# When executed as a script, run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
