#!/usr/bin/env python3
"""Console v2 MCP shim — gives embedded Claude awareness of the current selection.

Exposes two tools:
  get_current_selection()  — Returns {project_id, beat_id, take_id} from the FastAPI selection store.
  create_proposal(kind, payload) — Creates a proposal in ProposalTray via the FastAPI proposals endpoint.

Background thread subscribes to /api/events/stream SSE. On a selection/changed
event, emits MCP notifications/resources/list_changed so Claude knows to re-call
get_current_selection() if it wants fresh context.

Usage (in project .claude/mcp.json or global ~/.claude.json):
    {
        "console-shim": {
            "command": "python3",
            "args": ["/path/to/recoil/api/console_mcp_shim.py"]
        }
    }
"""

import argparse
import http.client
import json
import logging
import sys
import threading
import time
import urllib.error
import urllib.parse
import urllib.request

# stdout is the JSON-RPC channel — all logging must go to stderr.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    format="[console-mcp] %(levelname)s %(message)s",
)
log = logging.getLogger("console-mcp")

# Base URL of the Console API that the tool handlers call. When the file is
# run as a script this is replaced by the (right-stripped) --api-url argument.
_API_URL: str = "http://127.0.0.1:8431"  # overridden by --api-url

# Tool registry populated by the _register_tool decorator. Maps tool name to a
# dict with "name", "description", "inputSchema" and "handler" entries.
_TOOLS: dict[str, dict] = {}

# Only PromptRewriteProposal is wired in this shim — the payload builder uses
# beat_id + new_text which maps to the PromptRewrite contract. Other kinds need
# kind-specific payload shapes (take: prefix + key fields for ParameterChange, etc.)
# and will be added to the shim when their dispatch is wired.
_VALID_KINDS = {"PromptRewriteProposal"}


def _register_tool(name: str, description: str, input_schema: dict):
    """Return a decorator that records a tool handler in ``_TOOLS``.

    Args:
        name: Tool name; used as the registry key and stored in the entry.
        description: Human-readable tool description stored in the entry.
        input_schema: Schema dict stored under the ``"inputSchema"`` key.

    Returns:
        A decorator that stores the decorated function as the entry's
        ``"handler"`` and hands the function back unchanged.
    """

    def register(handler):
        # Keyword order fixes the entry's key order: name, description,
        # inputSchema, handler.
        _TOOLS[name] = dict(
            name=name,
            description=description,
            inputSchema=input_schema,
            handler=handler,
        )
        return handler

    return register


@_register_tool(
    name="get_current_selection",
    description=(
        "Returns the beat and take currently focused in the Console v2 UI. "
        "Call this to understand what the user is looking at before offering "
        "generation or review suggestions."
    ),
    input_schema={"type": "object", "properties": {}, "required": []},
)
def tool_get_current_selection(args: dict) -> dict:
    """Fetch the current selection from the Console API.

    Args:
        args: Tool arguments. Not read; the tool takes no input.

    Returns:
        The decoded JSON body of ``GET {_API_URL}/api/selection/current``.

    Raises:
        RuntimeError: If the HTTP request, the read, or the JSON decoding
            fails for any reason; the original exception is chained.
    """
    request = urllib.request.Request(
        f"{_API_URL}/api/selection/current", method="GET"
    )
    try:
        with urllib.request.urlopen(request, timeout=3.0) as response:
            raw_body = response.read()
        return json.loads(raw_body.decode("utf-8"))
    except Exception as exc:
        raise RuntimeError(f"selection API unreachable: {exc}") from exc


@_register_tool(
    name="create_proposal",
    description=(
        "Creates a PromptRewriteProposal in the Console v2 ProposalTray. "
        "Use this to suggest a new prompt for the currently selected beat. "
        "The proposal appears in the tray for user review and approval. "
        "payload must include new_text (the proposed prompt), and optionally "
        "beat_id (inferred from current selection if absent), title, and project."
    ),
    input_schema={
        "type": "object",
        "properties": {
            "kind": {
                "type": "string",
                "description": "The proposal kind. Currently only 'PromptRewriteProposal' is supported.",
            },
            "payload": {
                "type": "object",
                "description": (
                    "Proposal payload. Keys: beat_id (str, optional — inferred from "
                    "current selection if absent), new_text (str), title (str, optional), "
                    "project (str, optional, default 'default')."
                ),
            },
        },
        "required": ["kind", "payload"],
    },
)
def tool_create_proposal(args: dict) -> dict:
    """Create a proposal by POSTing to ``{_API_URL}/api/chat/proposals``.

    Args:
        args: Tool arguments with ``kind`` (must be in ``_VALID_KINDS``) and
            ``payload``, an object holding ``new_text`` (required, non-empty
            string) and optionally ``beat_id``, ``title`` and ``project``.

    Returns:
        The decoded JSON response of the proposals endpoint on success.
        Every failure — invalid arguments, no resolvable beat, HTTP error,
        unreachable API — is reported as a dict with an ``"error"`` key
        (plus ``"detail"`` for HTTP errors) rather than raised.
    """
    kind = args.get("kind", "")
    # The isinstance check keeps unhashable values (e.g. a list) from raising
    # TypeError in the set membership test.
    if not isinstance(kind, str) or kind not in _VALID_KINDS:
        return {
            "error": f"invalid kind {kind!r}; valid kinds: {sorted(_VALID_KINDS)}"
        }

    payload = args.get("payload")
    if payload is None:
        payload = {}
    if not isinstance(payload, dict):
        return {
            "error": f"payload must be an object, got {type(payload).__name__}"
        }

    # new_text is the only mandatory payload field; reject it up front instead
    # of submitting an empty rewrite.
    new_text = payload.get("new_text")
    if not isinstance(new_text, str) or not new_text.strip():
        return {"error": "payload.new_text is required and must be a non-empty string"}

    beat_id = payload.get("beat_id")
    if not beat_id:
        # Fall back to whatever beat the UI currently has selected.
        try:
            sel = tool_get_current_selection({})
            beat_id = sel.get("beat_id") or ""
        except Exception as exc:
            return {"error": f"beat_id not provided and selection API unreachable: {exc}"}
        if not beat_id:
            return {"error": "beat_id not provided and no beat is currently selected"}

    body = {
        "target": f"beat:{beat_id}",
        "kind": kind,
        "diff": [{"kind": "rewrite", "after": new_text}],
        # "or" (not a .get default) so an explicit null/empty value also
        # falls back to the default.
        "title": payload.get("title") or f"Proposal: {kind}",
        "est_cost_usd": 0.0,
        "est_time": "instant",
        "project": payload.get("project") or "default",
    }

    req = urllib.request.Request(
        f"{_API_URL}/api/chat/proposals",
        data=json.dumps(body).encode("utf-8"),
        method="POST",
        headers={"Content-Type": "application/json"},
    )
    try:
        with urllib.request.urlopen(req, timeout=10.0) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        # Prefer the JSON error body when there is one; otherwise fall back
        # to the exception text.
        try:
            detail = json.loads(exc.read().decode("utf-8"))
        except Exception:
            detail = str(exc)
        return {"error": f"HTTP {exc.code} from proposals API", "detail": detail}
    except Exception as exc:
        return {"error": f"proposals API unreachable: {exc}"}


def _handle_request(request: dict) -> dict | None:
    method = request.get("method", "")
    req_id = request.get("id")
    params = request.get("params", {})

    if method == "initialize":
        return {
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "serverInfo": {"name": "recoil-console", "version": "0.1.0"},
            },
        }

    if method == "notifications/initialized":
        return None

    if method == "tools/list":
        return {
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {
                "tools": [
                    {
                        "name": t["name"],
                        "description": t["description"],
                        "inputSchema": t["inputSchema"],
                    }
                    for t in _TOOLS.values()
                ]
            },
        }

    if method == "tools/call":
        tool_name = params.get("name", "")
        tool_args = params.get("arguments", {})
        tool_def = _TOOLS.get(tool_name)
        if tool_def is None:
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
                    "isError": True,
                },
            }
        try:
            result = tool_def["handler"](tool_args)
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "content": [
                        {
                            "type": "text",
                            "text": json.dumps(result, indent=2, default=str),
                        }
                    ],
                    "isError": False,
                },
            }
        except Exception as e:
            log.exception("Tool %s failed", tool_name)
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "content": [{"type": "text", "text": f"Error: {e}"}],
                    "isError": True,
                },
            }

    return {
        "jsonrpc": "2.0",
        "id": req_id,
        "error": {"code": -32601, "message": f"Method not found: {method}"},
    }


def _sse_listener(api_url: str, write_lock: threading.Lock) -> None:
    """Follow the API's SSE stream and notify the MCP client of selection changes.

    Connects to ``<api_url>/api/events/stream`` and reads it line by line.
    For every ``data:`` line whose JSON object has
    ``scope == "selection/changed"`` it writes a
    ``notifications/resources/list_changed`` JSON-RPC notification to stdout.
    The function never returns: on any error, a non-200 response, or the end
    of the stream it closes the connection, waits 3 seconds and reconnects.

    NOTE(review): the ``initialize`` response in this file advertises only the
    ``tools`` capability — confirm the client acts on this notification.

    Args:
        api_url: Base URL of the Console API. ``https`` URLs use a TLS
            connection, and any path component is kept as a prefix.
        write_lock: Lock held while writing to stdout so a notification
            cannot interleave with another writer using the same lock.
    """
    parsed = urllib.parse.urlparse(api_url)
    # Only the selected branch is evaluated, so HTTPSConnection is looked up
    # only for https URLs.
    conn_cls = (
        http.client.HTTPSConnection
        if parsed.scheme == "https"
        else http.client.HTTPConnection
    )
    # Keep any base path so this agrees with the "<api_url>/api/..." URLs
    # used for the other endpoints.
    stream_path = parsed.path.rstrip("/") + "/api/events/stream"
    notification = json.dumps(
        {"jsonrpc": "2.0", "method": "notifications/resources/list_changed"}
    )

    while True:
        conn = None
        try:
            conn = conn_cls(parsed.netloc, timeout=60)
            conn.request(
                "GET",
                stream_path,
                headers={"Accept": "text/event-stream", "Cache-Control": "no-cache"},
            )
            resp = conn.getresponse()
            if resp.status != 200:
                # Do not try to parse an error body as an event stream.
                raise RuntimeError(f"unexpected HTTP {resp.status} from event stream")

            event_lines: list[str] = []
            while True:
                raw = resp.readline()
                if not raw:
                    break  # stream ended
                line = raw.decode("utf-8", errors="replace").rstrip("\r\n")
                if line:
                    # Guard against malformed/non-SSE lines growing unboundedly.
                    if len(line) > 65536:
                        event_lines = []
                        continue
                    event_lines.append(line)
                    continue

                # A blank line terminates the event: inspect its data lines.
                for ln in event_lines:
                    if not ln.startswith("data:"):
                        continue
                    try:
                        payload = json.loads(ln[5:].strip())
                    except json.JSONDecodeError:
                        continue
                    # Non-object JSON (list, string, number) has no "scope".
                    if (
                        isinstance(payload, dict)
                        and payload.get("scope") == "selection/changed"
                    ):
                        with write_lock:
                            sys.stdout.write(notification + "\n")
                            sys.stdout.flush()
                event_lines = []

            log.info("SSE stream ended (will reconnect in 3s)")
        except Exception as exc:
            log.warning("SSE listener error (will retry in 3s): %s", exc)
        finally:
            # Always release the socket before reconnecting.
            if conn is not None:
                conn.close()
        # Pause on every exit path, including a clean end of stream, so a
        # server that closes immediately is not hammered in a tight loop.
        time.sleep(3.0)


def _main_loop(write_lock: threading.Lock) -> None:
    """Serve newline-delimited JSON-RPC from stdin until it reaches EOF.

    Each non-blank line is parsed as JSON and passed to ``_handle_request``;
    a non-``None`` response is written to stdout as a single line. A line
    that is not valid JSON gets a ``-32700`` parse error. An exception
    raised while handling a message is logged and answered with a ``-32603``
    internal error (notifications get no answer), so one bad message cannot
    stop the loop.

    Args:
        write_lock: Lock held around every stdout write so responses cannot
            interleave with another writer using the same lock.
    """

    def emit(message: dict) -> None:
        with write_lock:
            sys.stdout.write(json.dumps(message) + "\n")
            sys.stdout.flush()

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            request = json.loads(line)
        except json.JSONDecodeError as e:
            emit(
                {
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {"code": -32700, "message": f"Parse error: {e}"},
                }
            )
            continue

        try:
            response = _handle_request(request)
        except Exception as exc:
            log.exception("Unhandled error while handling request")
            is_object = isinstance(request, dict)
            if is_object and "id" not in request:
                # A notification: JSON-RPC forbids replying, even on error.
                continue
            response = {
                "jsonrpc": "2.0",
                "id": request.get("id") if is_object else None,
                "error": {"code": -32603, "message": f"Internal error: {exc}"},
            }

        if response is not None:
            emit(response)


if __name__ == "__main__":
    # Script entry point: parse the CLI, start the SSE listener in the
    # background, then serve JSON-RPC on stdin/stdout in the foreground.
    arg_parser = argparse.ArgumentParser(
        description="Console v2 MCP shim — selection context for embedded Claude."
    )
    # The module-level constant doubles as the CLI default.
    arg_parser.add_argument("--api-url", default=_API_URL)
    cli_args = arg_parser.parse_args()
    # Drop trailing slashes so endpoint paths can be appended with one "/".
    _API_URL = cli_args.api_url.rstrip("/")

    # The same lock is handed to both the listener thread and the main loop.
    _write_lock = threading.Lock()

    # Daemon thread: it must not keep the process alive once stdin closes.
    threading.Thread(
        target=_sse_listener,
        args=(_API_URL, _write_lock),
        daemon=True,
        name="sse-listener",
    ).start()

    _main_loop(_write_lock)
