"""Tests for console_mcp_shim — JSON-RPC 2.0 protocol + get_current_selection tool."""
from __future__ import annotations

import json
import unittest
import urllib.error
from unittest.mock import MagicMock, patch

import recoil.api.console_mcp_shim as shim


def _make_mock_resp(payload: dict) -> MagicMock:
    """Build a fake ``urlopen`` response whose body is *payload* as UTF-8 JSON.

    The returned mock is usable as a context manager: entering it yields the
    mock itself, and exiting returns ``False`` so exceptions are not suppressed.
    """
    body = json.dumps(payload).encode("utf-8")
    fake_resp = MagicMock()
    fake_resp.read.return_value = body
    fake_resp.__enter__.return_value = fake_resp
    fake_resp.__exit__.return_value = False
    return fake_resp


class TestGetCurrentSelectionSuccess(unittest.TestCase):
    """``tool_get_current_selection`` returns the fields served by the selection API."""

    def test_get_current_selection_success(self):
        expected = {
            "project_id": "tartarus",
            "beat_id": "EP001_SH01",
            "take_id": "EP001_SH01_T001",
        }
        fake_resp = _make_mock_resp(expected)
        with patch("urllib.request.urlopen", return_value=fake_resp):
            selection = shim.tool_get_current_selection({})

        for key in ("project_id", "beat_id", "take_id"):
            self.assertEqual(selection[key], expected[key])


class TestGetCurrentSelectionApiDown(unittest.TestCase):
    """A URLError from the selection API is surfaced as a RuntimeError."""

    def test_get_current_selection_api_down(self):
        refused = urllib.error.URLError("connection refused")
        patched = patch("urllib.request.urlopen", side_effect=refused)
        # patch is entered first (outer), assertRaises second (inner).
        with patched, self.assertRaises(RuntimeError) as ctx:
            shim.tool_get_current_selection({})

        self.assertIn("selection API unreachable", str(ctx.exception))


class TestToolsList(unittest.TestCase):
    """``tools/list`` advertises the shim's tools with their metadata."""

    def test_tools_list(self):
        response = shim._handle_request(
            {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}
        )

        self.assertIsNotNone(response)
        self.assertEqual(response["jsonrpc"], "2.0")
        self.assertEqual(response["id"], 1)

        tools = response["result"]["tools"]
        advertised = {tool["name"] for tool in tools}
        for required in ("get_current_selection", "create_proposal"):
            self.assertIn(required, advertised)

        # First tool entry with the matching name, as a client would resolve it.
        selection_tool = next(
            tool for tool in tools if tool["name"] == "get_current_selection"
        )
        for field in ("description", "inputSchema"):
            self.assertIn(field, selection_tool)


class TestInitializeHandshake(unittest.TestCase):
    """``initialize`` echoes the protocol version and identifies the server."""

    def test_initialize_handshake(self):
        client_params = {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {"name": "test", "version": "0"},
        }
        response = shim._handle_request(
            {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": client_params}
        )

        self.assertIsNotNone(response)
        self.assertEqual(response["jsonrpc"], "2.0")
        self.assertEqual(response["id"], 1)

        handshake = response["result"]
        self.assertEqual(handshake["protocolVersion"], "2024-11-05")
        for key in ("capabilities", "serverInfo"):
            self.assertIn(key, handshake)
        self.assertEqual(handshake["serverInfo"]["name"], "recoil-console")


class TestUnknownMethod(unittest.TestCase):
    """Handling of methods the shim does not answer with a result."""

    def test_unknown_method(self):
        response = shim._handle_request(
            {"jsonrpc": "2.0", "id": 99, "method": "unknown/method", "params": {}}
        )

        self.assertIsNotNone(response)
        self.assertEqual(response["jsonrpc"], "2.0")
        self.assertEqual(response["id"], 99)
        self.assertIn("error", response)
        # -32601 is the JSON-RPC 2.0 "Method not found" error code.
        self.assertEqual(response["error"]["code"], -32601)

    def test_notifications_initialized_returns_none(self):
        # A notification carries no "id"; the shim is expected to send no reply.
        notification = {
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {},
        }
        reply = shim._handle_request(notification)
        self.assertIsNone(reply)


class TestToolsCallDispatch(unittest.TestCase):
    """``tools/call`` routes to the named tool and wraps its outcome."""

    @staticmethod
    def _tools_call(request_id, tool_name):
        """Build a ``tools/call`` JSON-RPC request with empty tool arguments."""
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": "tools/call",
            "params": {"name": tool_name, "arguments": {}},
        }

    def test_tools_call_unknown_tool(self):
        response = shim._handle_request(self._tools_call(2, "nonexistent_tool"))

        self.assertIsNotNone(response)
        outcome = response["result"]
        self.assertTrue(outcome["isError"])
        self.assertIn("Unknown tool", outcome["content"][0]["text"])

    def test_tools_call_get_current_selection_success(self):
        fake_resp = _make_mock_resp({"project_id": "tartarus", "beat_id": "EP001_SH02"})
        with patch("urllib.request.urlopen", return_value=fake_resp):
            response = shim._handle_request(
                self._tools_call(3, "get_current_selection")
            )

        self.assertIsNotNone(response)
        outcome = response["result"]
        self.assertFalse(outcome["isError"])
        # The tool's return value arrives JSON-encoded in the first text item.
        decoded = json.loads(outcome["content"][0]["text"])
        self.assertEqual(decoded["project_id"], "tartarus")

    def test_tools_call_get_current_selection_error(self):
        refused = urllib.error.URLError("refused")
        with patch("urllib.request.urlopen", side_effect=refused):
            response = shim._handle_request(
                self._tools_call(4, "get_current_selection")
            )

        self.assertIsNotNone(response)
        outcome = response["result"]
        self.assertTrue(outcome["isError"])
        self.assertIn("Error:", outcome["content"][0]["text"])


class TestCreateProposal(unittest.TestCase):
    """Tests for ``tool_create_proposal``, called directly and via JSON-RPC."""

    def test_create_proposal_posts_correct_body(self):
        """create_proposal sends POST /api/chat/proposals with correct body and returns ID."""
        api_response = {"ok": True, "id": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4"}

        with patch(
            "urllib.request.urlopen", return_value=_make_mock_resp(api_response)
        ) as urlopen_mock:
            result = shim.tool_create_proposal({
                "kind": "PromptRewriteProposal",
                "payload": {
                    "beat_id": "EP001_SH01",
                    "new_text": "Rewritten prompt.",
                    "project": "tartarus",
                },
            })

        # call_args only reflects the most recent call; pin that there was
        # exactly one request so the inspected request is the only one sent.
        urlopen_mock.assert_called_once()
        req_obj = urlopen_mock.call_args[0][0]
        self.assertEqual(req_obj.method, "POST")
        self.assertIn("/api/chat/proposals", req_obj.full_url)
        sent_body = json.loads(req_obj.data.decode("utf-8"))
        self.assertEqual(sent_body["kind"], "PromptRewriteProposal")
        self.assertEqual(sent_body["target"], "beat:EP001_SH01")
        self.assertEqual(sent_body["diff"][0]["after"], "Rewritten prompt.")
        self.assertEqual(sent_body["project"], "tartarus")
        self.assertEqual(result["id"], "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4")
        self.assertTrue(result["ok"])

    def test_create_proposal_invalid_kind_returns_error(self):
        """create_proposal with invalid kind returns error dict, does not raise."""
        # Patch urlopen so the test never touches a real server. An invalid
        # kind is expected to be rejected before any HTTP request is made.
        with patch("urllib.request.urlopen") as urlopen_mock:
            result = shim.tool_create_proposal({
                "kind": "NotARealKind",
                "payload": {"beat_id": "EP001_SH01", "new_text": "text"},
            })

        urlopen_mock.assert_not_called()
        self.assertIsInstance(result, dict)
        self.assertIn("error", result)
        self.assertIn("NotARealKind", result["error"])

    def test_create_proposal_via_jsonrpc_dispatch(self):
        """create_proposal is reachable via tools/call JSON-RPC method."""
        api_response = {"ok": True, "id": "deadbeefdeadbeefdeadbeefdeadbeef"}
        request = {
            "jsonrpc": "2.0",
            "id": 5,
            "method": "tools/call",
            "params": {
                "name": "create_proposal",
                "arguments": {
                    "kind": "PromptRewriteProposal",
                    "payload": {
                        "beat_id": "EP001_SH02",
                        "new_text": "New text here.",
                        "project": "tartarus",
                    },
                },
            },
        }
        with patch("urllib.request.urlopen", return_value=_make_mock_resp(api_response)):
            response = shim._handle_request(request)

        self.assertIsNotNone(response)
        result = response["result"]
        self.assertFalse(result["isError"])
        # The tool's return value arrives JSON-encoded in the first text item.
        parsed = json.loads(result["content"][0]["text"])
        self.assertTrue(parsed["ok"])
        self.assertEqual(parsed["id"], "deadbeefdeadbeefdeadbeefdeadbeef")


if __name__ == "__main__":
    # Allow running this file directly: discover and run its TestCase classes.
    unittest.main(module="__main__")
