Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,20 +236,25 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.id

async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
event_source: EventSource | None = None
try:
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as es:
event_source = es
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")

async for sse in event_source.aiter_sse(): # pragma: no branch
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
await event_source.response.aclose()
break
async for sse in event_source.aiter_sse(): # pragma: no branch
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
finally:
if event_source is not None: # pragma: no branch
await event_source.response.aclose()

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
Expand Down Expand Up @@ -361,10 +366,11 @@ async def _handle_sse_response(
# If the SSE event indicates completion, like returning response/error
# break the loop
if is_complete:
await response.aclose()
return # Normal completion, no reconnect needed
except Exception:
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
logger.debug("SSE stream ended", exc_info=True)
finally:
await response.aclose()

# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
Expand Down
128 changes: 128 additions & 0 deletions tests/client/test_streamable_http_response_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import contextlib

import httpx
import pytest
from httpx_sse import ServerSentEvent

from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport
from mcp.shared._context_streams import create_context_streams
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import JSONRPCRequest


class _RaiseEventSource:
def __init__(self, response: httpx.Response) -> None:
self.response = response

async def aiter_sse(self):
yield ServerSentEvent(event="message", data="", id=None, retry=None)
raise RuntimeError("boom")


@pytest.mark.anyio
async def test_handle_sse_response_closes_response_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
closed = False

async def spy_aclose() -> None:
nonlocal closed
closed = True

response = httpx.Response(200, headers={"content-type": "text/event-stream"})
response.aclose = spy_aclose # type: ignore[method-assign]

monkeypatch.setattr("mcp.client.streamable_http.EventSource", _RaiseEventSource)

send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1)
async with send_stream, receive_stream:
transport = StreamableHTTPTransport("http://example.invalid/mcp")
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client:
ctx = RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)),
metadata=ClientMessageMetadata(),
read_stream_writer=send_stream,
)
await transport._handle_sse_response(response, ctx)

assert closed


@pytest.mark.anyio
async def test_handle_resumption_request_closes_response_when_aconnect_sse_raises(
monkeypatch: pytest.MonkeyPatch,
) -> None:
@contextlib.asynccontextmanager
async def fake_aconnect_sse(*_args: object, **_kwargs: object):
raise RuntimeError("connect failed")
yield

monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse)

send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1)
async with send_stream, receive_stream:
transport = StreamableHTTPTransport("http://example.invalid/mcp")
metadata = ClientMessageMetadata(resumption_token="1")
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client:
ctx = RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)),
metadata=metadata,
read_stream_writer=send_stream,
)

error: RuntimeError | None = None
try:
await transport._handle_resumption_request(ctx)
except RuntimeError as exc:
error = exc

assert error is not None
assert str(error) == "connect failed"


@pytest.mark.anyio
async def test_handle_resumption_request_closes_response_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
closed = False

async def spy_aclose() -> None:
nonlocal closed
closed = True

response = httpx.Response(
200,
headers={"content-type": "text/event-stream"},
request=httpx.Request("GET", "http://example.invalid/mcp"),
)
response.aclose = spy_aclose # type: ignore[method-assign]

@contextlib.asynccontextmanager
async def fake_aconnect_sse(*_args: object, **_kwargs: object):
yield _RaiseEventSource(response)

monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse)

send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1)
async with send_stream, receive_stream:
transport = StreamableHTTPTransport("http://example.invalid/mcp")
metadata = ClientMessageMetadata(resumption_token="1")
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client:
ctx = RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)),
metadata=metadata,
read_stream_writer=send_stream,
)

error: RuntimeError | None = None
try:
await transport._handle_resumption_request(ctx)
except RuntimeError as exc:
error = exc

assert error is not None
assert str(error) == "boom"

assert closed
Loading