#!/usr/bin/env python3
"""Append-only local claim ledger for Studio autonomy."""

from __future__ import annotations

import json
import socket
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from recoil.pipeline.tools.autonomy import constants


# Ledger file location, taken from the shared autonomy constants. The helpers
# in this module use it as a path object (``read_text``, ``open``, ``parent``).
CLAIM_LEDGER = constants.CLAIM_LEDGER
# State written by claim(); an issue is considered claimed while its most
# recent ledger record carries this state.
ACTIVE_STATE = "active"
# The only states release() accepts when closing out an active claim.
TERMINAL_STATES = {"released", "completed", "failed"}
# Union of the active state and the terminal states.
VALID_STATES = {ACTIVE_STATE, *TERMINAL_STATES}


def _utc_now_iso() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (second precision)."""
    now = datetime.now(tz=timezone.utc)
    # The format spec is handed to strftime; the trailing "Z" is a literal.
    return f"{now:%Y-%m-%dT%H:%M:%SZ}"


def _records() -> list[dict[str, Any]]:
    """Load every JSON-object line of the ledger, in file order.

    A missing ledger file yields an empty list. Blank lines are ignored, and
    lines whose JSON value is not an object are dropped. A line that is not
    valid JSON is not handled here, so ``json.JSONDecodeError`` propagates to
    the caller.
    """
    try:
        text = CLAIM_LEDGER.read_text(encoding="utf-8")
    except FileNotFoundError:
        return []

    # Lazily decode the non-blank lines, then keep only the dict payloads.
    decoded = (json.loads(raw) for raw in text.splitlines() if raw.strip())
    return [entry for entry in decoded if isinstance(entry, dict)]


def _append(record: dict[str, Any]) -> dict[str, Any]:
    """Write ``record`` as one compact, key-sorted JSON line at the end of the ledger.

    The ledger's parent directory is created first if it does not exist.
    Returns the same ``record`` object that was passed in.
    """
    ledger_dir = CLAIM_LEDGER.parent
    ledger_dir.mkdir(parents=True, exist_ok=True)
    with CLAIM_LEDGER.open("a", encoding="utf-8") as ledger:
        line = json.dumps(record, sort_keys=True, separators=(",", ":"))
        ledger.write(f"{line}\n")
    return record


def _attempt(record: dict[str, Any]) -> int:
    """Return the attempt number carried by ``record``, or 0 if none can be read.

    Lookup order:
      1. the ``attempt`` field when it is already an ``int``;
      2. the ``attempt`` field parsed from its string form;
      3. the text between the first and second ``:`` of ``record_id``.
    """
    stored = record.get("attempt")
    if isinstance(stored, int):
        return stored
    try:
        return int(str(stored))
    except (TypeError, ValueError):
        pass  # not numeric text; fall back to the record_id segment

    _head, colon, tail = str(record.get("record_id") or "").partition(":")
    if not colon:
        # No ":" at all, so there is no attempt segment to parse.
        return 0
    try:
        return int(tail.partition(":")[0])
    except ValueError:
        return 0


def _issue_records(issue_id: str) -> list[dict[str, Any]]:
    """Return the ledger records whose ``issue_id`` field equals ``issue_id``, in ledger order."""

    def belongs_to_issue(entry: dict[str, Any]) -> bool:
        return entry.get("issue_id") == issue_id

    return list(filter(belongs_to_issue, _records()))


def active_claim(issue_id: str) -> dict[str, Any] | None:
    """Return the newest ledger record for ``issue_id`` if its state is active.

    Returns ``None`` when the issue has no records, or when its newest record
    carries any other state.
    """
    history = _issue_records(issue_id)
    newest = history[-1] if history else None
    if newest is not None and newest.get("state") == ACTIVE_STATE:
        return newest
    return None


def claim(
    issue_id: str, issue_identifier: str, run_id: str, night_id: str
) -> dict[str, Any] | None:
    """Append a new active claim unless this issue is already claimed.

    The ledger is read once per call; both the "already claimed" check and the
    next attempt number are derived from that single snapshot.

    Args:
        issue_id: Key used to match existing ledger records; also the first
            segment of the new ``record_id``.
        issue_identifier: Stored on the record as given.
        run_id: Stored on the record and used as the last ``record_id`` segment.
        night_id: Stored on the record as given.

    Returns:
        The appended claim record, or ``None`` when the newest record for
        ``issue_id`` is still active (nothing is written in that case).

    Note:
        The check and the append are separate steps and no locking is done
        here, so concurrent writers are not serialized by this function.
    """
    history = _issue_records(issue_id)
    # Same rule as active_claim(): an issue is claimed while its newest record is active.
    if history and history[-1].get("state") == ACTIVE_STATE:
        return None

    # Attempts count up from the highest attempt seen so far for this issue.
    attempt = max((_attempt(entry) for entry in history), default=0) + 1
    claimed_at = _utc_now_iso()
    record = {
        "record_id": f"{issue_id}:{attempt}:{run_id}",
        "attempt": attempt,
        "issue_id": issue_id,
        "issue_identifier": issue_identifier,
        "run_id": run_id,
        "night_id": night_id,
        "state": ACTIVE_STATE,
        "host": socket.gethostname(),
        "claimed_at": claimed_at,
    }
    return _append(record)


def release(
    issue_id: str,
    run_id: str,
    state: str,
    failure_signature: str | None = None,
) -> dict[str, Any] | None:
    """Close the active claim held by ``run_id`` on ``issue_id`` with a terminal record.

    The new record is a copy of the active claim with ``state`` replaced and a
    ``released_at`` timestamp set. Any ``failure_signature`` on the copied
    claim is dropped; the key is written only when a non-``None``
    ``failure_signature`` argument is supplied.

    Returns:
        The appended terminal record, or ``None`` when the issue has no active
        claim or the active claim belongs to a different ``run_id``.

    Raises:
        ValueError: If ``state`` is not one of ``TERMINAL_STATES``.
    """
    if state not in TERMINAL_STATES:
        raise ValueError(f"invalid terminal claim state: {state}")

    current = active_claim(issue_id)
    held_by_this_run = current is not None and current.get("run_id") == run_id
    if not held_by_this_run:
        return None

    # Copy everything except a stale signature, then stamp the terminal fields.
    closing = {key: value for key, value in current.items() if key != "failure_signature"}
    closing.update(state=state, released_at=_utc_now_iso())
    if failure_signature is not None:
        closing["failure_signature"] = failure_signature
    return _append(closing)


def consecutive_failures(issue_id: str) -> int:
    """Count the run of trailing ``failed`` records that share the newest failure signature.

    Only records in a terminal state are considered. The result is 0 when
    there are none, or when the newest terminal record is not ``failed``.
    Counting stops at the first older terminal record that is not ``failed``
    or whose ``failure_signature`` differs.
    """
    closed = [
        entry
        for entry in _issue_records(issue_id)
        if entry.get("state") in TERMINAL_STATES
    ]
    if not closed:
        return 0

    newest = closed[-1]
    if newest.get("state") != "failed":
        return 0

    signature = newest.get("failure_signature")
    streak = 0
    # Walk backwards from the newest terminal record until the streak breaks.
    while streak < len(closed):
        entry = closed[-1 - streak]
        if entry.get("state") != "failed" or entry.get("failure_signature") != signature:
            break
        streak += 1
    return streak
