Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ async def __call__(
async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
await anyio.lowlevel.checkpoint()


Expand Down
6 changes: 4 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ async def _handle_session_message(message: SessionMessage) -> None:
except Exception as e:
# Other exceptions are not expected and should be logged. We purposefully
# catch all exceptions here to avoid crashing the server.
logging.exception(f"Unhandled exception in receive loop: {e}") # pragma: no cover
logging.exception(f"Unhandled exception in receive loop: {e}")
finally:
# after the read stream is closed, we need to send errors
# to any pending requests
Expand Down Expand Up @@ -516,7 +516,9 @@ async def _handle_response(self, message: SessionMessage) -> None:
if stream:
await stream.send(message.message)
else:
await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
logging.warning(
"Received response with unknown request ID %r — dropping (request may have timed out)", response_id
)

async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""Can be overridden by subclasses to handle a request without needing to
Expand Down
132 changes: 132 additions & 0 deletions tests/issues/test_1401_client_session_error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Tests for issue #1401: ClientSession should propagate transport exceptions.

Root cause: _default_message_handler called anyio.checkpoint() unconditionally,
silently dropping exceptions. The async-for loop in _receive_loop then called
`continue`, waiting for the next message that never came — hanging all pending
requests indefinitely.

Fix: _default_message_handler re-raises when the message is an Exception
(transport errors from the stream). This propagates out of _receive_loop's
async-for, triggering the finally block that closes all pending response streams
with CONNECTION_CLOSED — unblocking any in-flight callers.

Protocol-level non-fatal errors (e.g. responses with unknown request IDs from
timed-out requests) are handled inline in _handle_response with a warning log,
so they do not reach _default_message_handler and cannot kill the session.
"""

import anyio
import pytest
from anyio.abc import TaskStatus

from mcp import types
from mcp.client.session import ClientSession, _default_message_handler
from mcp.server import Server, ServerRequestContext
from mcp.shared.exceptions import MCPError
from mcp.shared.message import SessionMessage
from mcp.types import CallToolRequestParams, CallToolResult, TextContent


@pytest.mark.anyio
async def test_default_message_handler_raises_on_exception():
"""_default_message_handler must re-raise Exception instances."""
err = RuntimeError("transport failure")
with pytest.raises(RuntimeError, match="transport failure"):
await _default_message_handler(err)


@pytest.mark.anyio
async def test_default_message_handler_checkpoints_on_notification():
"""_default_message_handler should checkpoint (not raise) for non-exception messages."""
notification = types.ToolListChangedNotification(method="notifications/tools/list_changed")
# Should complete without raising
await _default_message_handler(notification)


@pytest.mark.anyio
async def test_transport_exception_unblocks_pending_request():
"""A transport exception must unblock pending requests instead of hanging them.

Before the fix: exception was swallowed by checkpoint(); async-for looped
back waiting for the next message; pending call_tool hung indefinitely.

After the fix: exception propagates out of the async-for, _receive_loop's
finally block closes all pending response streams with CONNECTION_CLOSED,
and call_tool raises MCPError rather than hanging.
"""
slow_tool_started = anyio.Event()

async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
slow_tool_started.set()
await anyio.sleep(60) # hangs until cancelled
return CallToolResult(content=[TextContent(type="text", text="never")]) # pragma: no cover

server = Server(
name="test",
on_call_tool=handle_call_tool,
)

server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage](4)
client_writer, client_reader = anyio.create_memory_object_stream[SessionMessage | Exception](4)

call_tool_error: Exception | None = None
server_scope: anyio.CancelScope | None = None

async def run_server(*, task_status: TaskStatus[anyio.CancelScope]) -> None:
with anyio.CancelScope() as scope:
task_status.started(scope)
await server.run(server_reader, client_writer, server.create_initialization_options())

async def run_client() -> None:
nonlocal call_tool_error
async with ClientSession(client_reader, server_writer) as session: # type: ignore[arg-type]
await session.initialize()

async def inject() -> None:
await slow_tool_started.wait()
# Inject a transport exception — simulates e.g. httpx.ReadTimeout
await client_writer.send(RuntimeError("sse read timeout"))

async with anyio.create_task_group() as tg:
tg.start_soon(inject)
try:
await session.call_tool("slow")
except (MCPError, RuntimeError) as e:
call_tool_error = e
tg.cancel_scope.cancel()

assert server_scope is not None
server_scope.cancel()

async with anyio.create_task_group() as tg:
server_scope = await tg.start(run_server)
tg.start_soon(run_client)

assert call_tool_error is not None, "call_tool should have raised, not hung"


@pytest.mark.anyio
async def test_custom_message_handler_receives_exception():
"""A custom message_handler can intercept transport exceptions without re-raising."""
received: list[Exception] = []

async def capturing_handler(message: object) -> None:
if isinstance(message, Exception): # pragma: lax no cover
received.append(message) # capture — do not re-raise

server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage](4)
client_writer, client_reader = anyio.create_memory_object_stream[SessionMessage | Exception](4)

async with server_reader, server_writer:
async with ClientSession(
client_reader, # type: ignore[arg-type]
server_writer.clone(),
message_handler=capturing_handler,
):
await client_writer.send(ValueError("custom handler test"))
await client_writer.aclose()
await anyio.sleep(0.05)

assert len(received) == 1
assert isinstance(received[0], ValueError)
assert str(received[0]) == "custom handler test"
Loading