#!/usr/bin/env python3
"""Reclaim stranded autonomy claims and stale worktrees."""

from __future__ import annotations

import os
import subprocess
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from recoil.pipeline.tools.autonomy import claim_ledger, events, lease
from recoil.pipeline.tools.autonomy import linear_client


# Consecutive-failure count (from claim_ledger.consecutive_failures) at or above
# which _maybe_mark_blocked calls linear_client.mark_blocked for an issue.
DEFAULT_BLOCK_AFTER_FAILURES = 3
# Passed as `--ttl-hours` to `session_workspace.sh reap`.
# NOTE(review): what a TTL of 0 means is defined by that script (presumably "no
# grace period"); confirm it cannot reap worktrees of still-live sessions.
DEFAULT_WORKTREE_TTL_HOURS = 0
# failure_signature recorded on claims released by the reaper; also used as the
# `reason` field of the build_killed event.
FAILURE_SIGNATURE = "reaper:stranded"
# session_workspace.sh, located in the parent of this module's directory.
SESSION_WORKSPACE = Path(__file__).resolve().parents[1] / "session_workspace.sh"


def _parse_iso(value: object) -> datetime | None:
    if not isinstance(value, str) or not value:
        return None
    try:
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None
    if parsed.tzinfo is None:
        return parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


def _pid_is_dead(pid: object) -> bool:
    try:
        pid_int = int(pid)
    except (TypeError, ValueError):
        return False
    if pid_int <= 0:
        return False
    try:
        os.kill(pid_int, 0)
    except ProcessLookupError:
        return True
    except OSError:
        return False
    return False


def _active_claims() -> list[dict[str, Any]]:
    """Return the latest ledger record per issue, keeping only active ones.

    For each issue id, the last record seen in ``claim_ledger._records()``
    wins. A winning record is returned only if its ``state`` equals
    ``claim_ledger.ACTIVE_STATE`` and it carries a truthy ``run_id``.
    """
    newest: dict[str, dict[str, Any]] = {}
    for entry in claim_ledger._records():  # noqa: SLF001 - no public all-records API exists.
        key = entry.get("issue_id")
        if not key:
            continue
        newest[str(key)] = entry

    active: list[dict[str, Any]] = []
    for entry in newest.values():
        if entry.get("state") == claim_ledger.ACTIVE_STATE and entry.get("run_id"):
            active.append(entry)
    return active


def _all_issue_ids() -> list[str]:
    """Return every issue id in the claim ledger once, in first-seen order.

    Records with a missing or falsy ``issue_id`` are skipped. Ids are
    stringified before de-duplication.
    """
    # dict preserves insertion order and setdefault never moves an existing key.
    ordered: dict[str, None] = {}
    for entry in claim_ledger._records():  # noqa: SLF001 - no public all-records API exists.
        raw = entry.get("issue_id")
        if raw:
            ordered.setdefault(str(raw), None)
    return list(ordered)


def _lease_dead_for_run(run_id: str, record: dict[str, Any] | None) -> bool:
    if record is None:
        return True
    if record.get("run_id") != run_id:
        return True

    expires_at = _parse_iso(record.get("expires_at"))
    if expires_at is None:
        return False
    if expires_at > datetime.now(timezone.utc):
        return False
    return _pid_is_dead(record.get("pid"))


def _workspace_reap(*, dry_run: bool) -> subprocess.CompletedProcess[str]:
    ttl_hours = str(DEFAULT_WORKTREE_TTL_HOURS)
    cmd = [str(SESSION_WORKSPACE), "reap", "--ttl-hours", ttl_hours]
    if dry_run:
        cmd.append("--dry-run")
    return subprocess.run(cmd, capture_output=True, text=True, check=False)


def _emit(event_type: str, **fields: Any) -> None:
    """Forward one named event and its keyword fields to ``events.emit``."""
    payload = dict(fields)
    events.emit(event_type, **payload)


def _maybe_mark_blocked(issue_id: str, *, dry_run: bool) -> bool:
    """Ask Linear to block ``issue_id`` once it reaches the failure cap.

    The failure count is always read first. If it is below
    ``DEFAULT_BLOCK_AFTER_FAILURES``, or ``dry_run`` is set, nothing is marked
    and ``False`` is returned. Otherwise ``linear_client.mark_blocked`` is
    called and a ``cap_tripped`` event is emitted. The call's result is
    returned as-is.
    """
    failures = claim_ledger.consecutive_failures(issue_id)
    if failures < DEFAULT_BLOCK_AFTER_FAILURES or dry_run:
        return False

    outcome = linear_client.mark_blocked(issue_id)
    _emit("cap_tripped", issue_id=issue_id, reason="consecutive_failures", marked=outcome)
    return outcome


def reap(dry_run: bool = False) -> list[str]:
    """Release active claims whose owning lease is gone or provably dead.

    The lease is read once, then compared against every active claim. For
    each claim whose lease is dead:

    * outside dry-run, the claim is released as ``"failed"`` with
      ``FAILURE_SIGNATURE``. If ``claim_ledger.release`` returns ``None`` the
      claim is skipped entirely. Otherwise a ``build_killed`` event is emitted;
    * ``session_workspace.sh reap`` is run (with ``--dry-run`` in preview
      mode). Outside dry-run, a ``maintenance_ran`` event records its exit
      code;
    * the claim's run id is appended to the result.

    Afterwards every issue id in the ledger is passed to
    ``_maybe_mark_blocked`` with the same ``dry_run`` flag.

    Args:
        dry_run: Preview mode. No claims are released, no events are emitted,
            and no issues are marked blocked.

    Returns:
        Run ids that were reaped, or in dry-run mode the ones that would be.
    """
    current_lease = lease.read()
    stranded_runs: list[str] = []

    for claim in _active_claims():
        run_id = str(claim.get("run_id"))
        issue_id = str(claim.get("issue_id"))
        if _lease_dead_for_run(run_id, current_lease):
            if dry_run:
                _workspace_reap(dry_run=dry_run)
            else:
                released = claim_ledger.release(
                    issue_id,
                    run_id,
                    "failed",
                    failure_signature=FAILURE_SIGNATURE,
                )
                if released is None:
                    # release() returned None (presumably nothing was released):
                    # skip this claim without touching workspaces or events.
                    continue
                _emit("build_killed", run_id=run_id, issue_id=issue_id, reason=FAILURE_SIGNATURE)
                # NOTE(review): this is the same global reap command for every
                # stranded claim; it runs once per claim, not once per pass.
                outcome = _workspace_reap(dry_run=dry_run)
                _emit(
                    "maintenance_ran",
                    run_id=run_id,
                    issue_id=issue_id,
                    action="session_workspace_reap",
                    returncode=outcome.returncode,
                )
            stranded_runs.append(run_id)

    for known_issue in _all_issue_ids():
        _maybe_mark_blocked(known_issue, dry_run=dry_run)

    return stranded_runs
