#!/usr/bin/env python3
"""Fail-closed preflight and human-yield checks for the autonomy tick."""

from __future__ import annotations

import json
import os
import shutil
import socket
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Callable, Mapping, Sequence

from recoil.pipeline.tools.autonomy import constants

_REPO_ROOT = Path(__file__).resolve().parents[3]
_SESSION_WORKSPACE = _REPO_ROOT / "recoil/pipeline/tools/session_workspace.sh"
_KNOWN_STUDIO_HOSTS = frozenset(
    {
        "joes-mac-studio-70960.local",
        "joes-mac-studio-70960",
        "joes-mac-studio-11047",
        "joes-mac-studio",
    }
)
_RECENT_LOGIN_MINUTES = 15

Runner = Callable[..., subprocess.CompletedProcess[str]]


def preflight(
    *,
    runner: Runner = subprocess.run,
    environ: Mapping[str, str] | None = None,
    hostname: str | None = None,
) -> tuple[bool, str]:
    """Check that this machine can run an autonomy tick.

    Checks run in order and stop at the first failure:

    1. ``claude``, ``codex`` and ``git`` are found by ``shutil.which``.
    2. ``constants.SHLOCK`` exists and is executable.
    3. ``LINEAR_API_KEY`` and ``AUTONOMY_LINEAR_TEAM`` are non-empty in *environ*.
    4. ``constants.ensure_state_dir()`` does not raise.
    5. The host is the expected one. With ``AUTONOMY_HOST`` set, the host is
       compared via ``_same_host``. Otherwise the normalized host must be in
       ``_KNOWN_STUDIO_HOSTS``.

    Args:
        runner: Unused; accepted so callers can pass the same injectable as
            ``human_active``.
        environ: Environment mapping to read; defaults to ``os.environ``.
        hostname: Host to check; when falsy, ``socket.gethostname()`` is used.

    Returns:
        ``(True, "ok")`` on success, else ``(False, reason)``.
    """
    del runner  # Present only so the signature mirrors human_active()'s runner.
    env = environ if environ is not None else os.environ

    absent_tool = next(
        (tool for tool in ("claude", "codex", "git") if shutil.which(tool) is None),
        None,
    )
    if absent_tool is not None:
        return False, f"missing {absent_tool}"

    shlock_path = Path(constants.SHLOCK)
    if not shlock_path.exists() or not os.access(shlock_path, os.X_OK):
        return False, "missing shlock"

    for required_var in ("LINEAR_API_KEY", "AUTONOMY_LINEAR_TEAM"):
        if not env.get(required_var):
            return False, f"missing {required_var}"

    try:
        constants.ensure_state_dir()
    except Exception as exc:
        return False, f"state dir failed: {exc}"

    this_host = hostname or socket.gethostname()
    pinned_host = env.get("AUTONOMY_HOST")
    if pinned_host:
        host_ok = _same_host(this_host, pinned_host)
    else:
        host_ok = _normalize_host(this_host) in _KNOWN_STUDIO_HOSTS
    if not host_ok:
        return False, "wrong host"

    return True, "ok"


def human_active(
    repo_root: str | Path,
    *,
    runner: Runner = subprocess.run,
    now: datetime | None = None,
    recent_login_minutes: int = _RECENT_LOGIN_MINUTES,
) -> tuple[bool, str]:
    """Report whether any soft-yield heuristic suggests a human is using the machine.

    The heuristics run in a fixed order and stop at the first hit:

    1. The canonical checkout is dirty.
    2. The canonical checkout is off ``main``.
    3. ``who`` shows a recent login.
    4. A non-autonomy tmux session exists.
    5. ``session_workspace.sh`` reports a conflict.

    Each check also reports its own command failures as active, so errors
    cause a yield.

    Returns:
        ``(True, reason)`` for the first check that fires, else ``(False, "idle")``.
    """
    repo = Path(repo_root)
    ordered_checks = (
        lambda: _canonical_dirty(repo, runner),
        lambda: _canonical_off_main(repo, runner),
        lambda: _recent_login(runner, now=now, minutes=recent_login_minutes),
        lambda: _tmux_active(runner),
        lambda: _workspace_conflict(repo, runner),
    )
    for check in ordered_checks:
        hit, why = check()
        if hit:
            return True, why
    return False, "idle"


def _canonical_dirty(repo: Path, runner: Runner) -> tuple[bool, str]:
    """Flag the canonical checkout when ``git status --porcelain`` prints anything.

    A failed or raising git call is returned as ``(True, "git_status_error")``.
    """
    outcome = _run(
        runner,
        ["git", "-C", str(repo), "status", "--porcelain"],
        error_reason="git_status_error",
    )
    if not isinstance(outcome, tuple):
        has_changes = bool(outcome.stdout.strip())
        return (True, "canonical_dirty") if has_changes else (False, "")
    return outcome


def _canonical_off_main(repo: Path, runner: Runner) -> tuple[bool, str]:
    """Flag the canonical checkout when its current branch is not ``main``.

    ``git rev-parse --abbrev-ref HEAD`` prints ``HEAD`` when detached, and that
    also counts as off-main. Git failures return ``(True, "git_branch_error")``.
    """
    outcome = _run(
        runner,
        ["git", "-C", str(repo), "rev-parse", "--abbrev-ref", "HEAD"],
        error_reason="git_branch_error",
    )
    if isinstance(outcome, tuple):
        return outcome
    branch = outcome.stdout.strip()
    return (False, "") if branch == "main" else (True, "canonical_off_main")


def _recent_login(
    runner: Runner,
    *,
    now: datetime | None,
    minutes: int,
) -> tuple[bool, str]:
    """Detect a ``who`` session whose login time lies in the recent window.

    The window runs from ``minutes`` before the reference time to one minute
    after it. The extra minute is presumably slack for clock skew; confirm.
    Any non-blank line whose time cannot be parsed counts as a recent login
    (fail closed). A failed ``who`` call returns ``(True, "who_error")``.
    """
    outcome = _run(runner, ["who"], error_reason="who_error")
    if isinstance(outcome, tuple):
        return outcome

    sessions = [entry for entry in outcome.stdout.splitlines() if entry.strip()]
    reference = now or datetime.now()
    window_start = reference - timedelta(minutes=minutes)
    window_end = reference + timedelta(minutes=1)

    for entry in sessions:
        logged_in_at = _parse_who_time(entry, reference)
        if logged_in_at is None or window_start <= logged_in_at <= window_end:
            return True, "recent_login"
    return False, ""


def _tmux_active(runner: Runner) -> tuple[bool, str]:
    """Detect tmux sessions that were not started by the autonomy tick.

    On a successful ``tmux ls``, any session whose name does not begin with
    ``autonomy-`` counts as human activity.

    Exit status 1 means "no server, so idle" when stdout is empty and stderr
    is either empty or mentions "no server running" / "failed to connect".

    Every other failure, including the runner raising, is ``tmux_error``.
    """
    try:
        listing = runner(
            ["tmux", "ls"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=False,
        )
    except Exception:
        return True, "tmux_error"

    if listing.returncode == 0:
        session_names = (
            row.split(":", 1)[0].strip() for row in listing.stdout.splitlines()
        )
        foreign = any(
            name and not name.startswith("autonomy-") for name in session_names
        )
        return (True, "tmux_active") if foreign else (False, "")

    err_text = listing.stderr.lower()
    server_absent = (
        listing.returncode == 1
        and not listing.stdout.strip()
        and (
            not err_text.strip()
            or any(
                marker in err_text
                for marker in ("no server running", "failed to connect")
            )
        )
    )
    return (False, "") if server_absent else (True, "tmux_error")


def _workspace_conflict(repo: Path, runner: Runner) -> tuple[bool, str]:
    """Ask ``session_workspace.sh observe --json`` whether the checkout is contended.

    The script runs with ``cwd`` set to *repo*. It returns
    ``(True, "workspace_conflict")`` when the report sets
    ``same_host_checkout_conflict`` or ``mutating_shared_checkout`` to a truthy
    value.

    Each of these returns ``(True, "workspace_error")`` so the check fails closed:
    - a failed run;
    - unparseable output;
    - a JSON payload that is not an object.
    """
    result = _run(
        runner,
        [str(_SESSION_WORKSPACE), "observe", "--json"],
        error_reason="workspace_error",
        cwd=str(repo),
    )
    if isinstance(result, tuple):
        return result

    try:
        observed = json.loads(result.stdout)
    except Exception:
        return True, "workspace_error"

    # A list/string/null payload has no ``.get``; report it as an error rather
    # than letting AttributeError escape human_active().
    if not isinstance(observed, dict):
        return True, "workspace_error"

    if observed.get("same_host_checkout_conflict") or observed.get(
        "mutating_shared_checkout"
    ):
        return True, "workspace_conflict"
    return False, ""


def _run(
    runner: Runner,
    args: Sequence[str],
    *,
    error_reason: str,
    cwd: str | None = None,
) -> subprocess.CompletedProcess[str] | tuple[bool, str]:
    """Run *args* through *runner* and fold every failure into one verdict.

    Output is captured as text with ``check=False``. If the runner raises, or
    the process exits non-zero, ``(True, error_reason)`` is returned. Callers
    can hand that tuple straight back as an "active" result. On success the
    ``CompletedProcess`` itself is returned.
    """
    failure = (True, error_reason)
    try:
        completed = runner(
            list(args),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=False,
            cwd=cwd,
        )
    except Exception:
        return failure
    return completed if completed.returncode == 0 else failure


def _parse_who_time(line: str, reference: datetime) -> datetime | None:
    """Extract the login timestamp from one line of ``who`` output.

    Two layouts are recognised at the end of the line:
    ``YYYY-MM-DD HH:MM`` and the year-less ``Mon DD HH:MM``.

    ``who`` may append a parenthesised comment column, e.g. ``(192.168.1.5)``
    or ``(:0)``. That column is dropped first; otherwise it would sit in the
    last fields, neither layout would match, and the caller would treat every
    such session as a recent login.

    For the year-less layout the year comes from *reference*. The previous
    year is tried instead if the date would land more than a day after
    *reference*, or if it does not exist in that year (``Feb 29``).

    Returns:
        The parsed naive ``datetime``, or ``None`` when no layout matches.
    """
    # Keep only the text before the optional "(comment)" column.
    fields = line.split("(", 1)[0].split()
    latest_allowed = reference + timedelta(days=1)

    for parts in (fields[-3:], fields[-2:]):
        candidate = " ".join(parts)
        if len(parts) == 2:
            try:
                return datetime.strptime(candidate, "%Y-%m-%d %H:%M")
            except ValueError:
                continue
        if len(parts) == 3:
            # Year-less format: try the reference year, then the year before.
            for year in (reference.year, reference.year - 1):
                try:
                    parsed = datetime.strptime(
                        f"{year} {candidate}", "%Y %b %d %H:%M"
                    )
                except ValueError:
                    continue
                if parsed <= latest_allowed:
                    return parsed
    return None


def _normalize_host(value: str) -> str:
    """Canonicalise a hostname for comparison: trim surrounding whitespace, then lowercase."""
    trimmed = value.strip()
    return trimmed.lower()


def _same_host(left: str, right: str) -> bool:
    """Return True when two hostnames should be treated as the same machine.

    Both sides are normalized with ``_normalize_host``. They match when the
    full names are equal, or when the labels before the first ``.`` are equal.
    So ``studio.local`` and ``studio`` compare equal.
    """
    a = _normalize_host(left)
    b = _normalize_host(right)
    if a == b:
        return True
    return a.partition(".")[0] == b.partition(".")[0]
