diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c633..fdd024b7d 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -236,20 +236,25 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch original_request_id = ctx.session_message.message.id - async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: - event_source.response.raise_for_status() - logger.debug("Resumption GET SSE connection established") + event_source: EventSource | None = None + try: + async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as es: + event_source = es + event_source.response.raise_for_status() + logger.debug("Resumption GET SSE connection established") - async for sse in event_source.aiter_sse(): # pragma: no branch - is_complete = await self._handle_sse_event( - sse, - ctx.read_stream_writer, - original_request_id, - ctx.metadata.on_resumption_token_update if ctx.metadata else None, - ) - if is_complete: - await event_source.response.aclose() - break + async for sse in event_source.aiter_sse(): # pragma: no branch + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + break + finally: + if event_source is not None: # pragma: no branch + await event_source.response.aclose() async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" @@ -361,10 +366,11 @@ async def _handle_sse_response( # If the SSE event indicates completion, like returning response/error # break the loop if is_complete: - await response.aclose() return # Normal completion, no reconnect needed except Exception: - logger.debug("SSE stream ended", exc_info=True) # pragma: no cover + logger.debug("SSE stream ended", exc_info=True) + finally: + await response.aclose() # Stream ended without response - reconnect if we received an event with ID if last_event_id is not None: # pragma: no branch diff --git a/tests/client/test_streamable_http_response_cleanup.py b/tests/client/test_streamable_http_response_cleanup.py new file mode 100644 index 000000000..0fa4ede44 --- /dev/null +++ b/tests/client/test_streamable_http_response_cleanup.py @@ -0,0 +1,128 @@ +import contextlib + +import httpx +import pytest +from httpx_sse import ServerSentEvent + +from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport +from mcp.shared._context_streams import create_context_streams +from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.types import JSONRPCRequest + + +class _RaiseEventSource: + def __init__(self, response: httpx.Response) -> None: + self.response = response + + async def aiter_sse(self): + yield ServerSentEvent(event="message", data="", id=None, retry=None) + raise RuntimeError("boom") + + +@pytest.mark.anyio +async def test_handle_sse_response_closes_response_on_exception(monkeypatch: pytest.MonkeyPatch) -> None: + closed = False + + async def spy_aclose() -> None: + nonlocal closed + closed = True + + response = httpx.Response(200, headers={"content-type": "text/event-stream"}) + response.aclose = spy_aclose # type: ignore[method-assign] + + monkeypatch.setattr("mcp.client.streamable_http.EventSource", _RaiseEventSource) + + send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1) + async with send_stream, receive_stream: + transport = StreamableHTTPTransport("http://example.invalid/mcp") + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client: + ctx = RequestContext( + client=client, + session_id=None, + session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)), + metadata=ClientMessageMetadata(), + read_stream_writer=send_stream, + ) + await transport._handle_sse_response(response, ctx) + + assert closed + + +@pytest.mark.anyio +async def test_handle_resumption_request_closes_response_when_aconnect_sse_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + @contextlib.asynccontextmanager + async def fake_aconnect_sse(*_args: object, **_kwargs: object): + raise RuntimeError("connect failed") + yield + + monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse) + + send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1) + async with send_stream, receive_stream: + transport = StreamableHTTPTransport("http://example.invalid/mcp") + metadata = ClientMessageMetadata(resumption_token="1") + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client: + ctx = RequestContext( + client=client, + session_id=None, + session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)), + metadata=metadata, + read_stream_writer=send_stream, + ) + + error: RuntimeError | None = None + try: + await transport._handle_resumption_request(ctx) + except RuntimeError as exc: + error = exc + + assert error is not None + assert str(error) == "connect failed" + + +@pytest.mark.anyio +async def test_handle_resumption_request_closes_response_on_exception(monkeypatch: pytest.MonkeyPatch) -> None: + closed = False + + async def spy_aclose() -> None: + nonlocal closed + closed = True + + response = httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + request=httpx.Request("GET", "http://example.invalid/mcp"), + ) + response.aclose = spy_aclose # type: ignore[method-assign] + + @contextlib.asynccontextmanager + async def fake_aconnect_sse(*_args: object, **_kwargs: object): + yield _RaiseEventSource(response) + + monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse) + + send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1) + async with send_stream, receive_stream: + transport = StreamableHTTPTransport("http://example.invalid/mcp") + metadata = ClientMessageMetadata(resumption_token="1") + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client: + ctx = RequestContext( + client=client, + session_id=None, + session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)), + metadata=metadata, + read_stream_writer=send_stream, + ) + + error: RuntimeError | None = None + try: + await transport._handle_resumption_request(ctx) + except RuntimeError as exc: + error = exc + + assert error is not None + assert str(error) == "boom" + + assert closed