From cef101ec5ca05afadca5fb1116e7ca6591b799ad Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Tue, 19 May 2026 02:36:19 +0800 Subject: [PATCH 1/2] fix: reject initialize protocol version conflicts --- src/mcp/server/streamable_http.py | 60 ++++++++++++++++++++-------- tests/shared/test_streamable_http.py | 37 +++++++++++++++++ 2 files changed, 80 insertions(+), 17 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f14201857c..982d9efc47 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -478,23 +478,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) return - # Check if this is an initialization request - is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" - - if is_initialization_request: - # Check if the server already has an established session - if self.mcp_session_id: - # Check if request has a session ID - request_session_id = self._get_session_id(request) - - # If request has a session ID but doesn't match, return 404 - if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover - response = self._create_error_response( - "Not Found: Invalid or expired session ID", - HTTPStatus.NOT_FOUND, - ) - await response(scope, receive, send) - return + is_initialization_request = False + if isinstance(message, JSONRPCRequest) and message.method == "initialize": + is_initialization_request = True + if not await self._validate_initialization_request(message, request, send): + return elif not await self._validate_request_headers(request, send): # pragma: no cover return @@ -867,6 +855,44 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True + async def _validate_initialization_request(self, message: JSONRPCRequest, request: Request, send: Send) -> bool: + if not await self._validate_initialization_protocol_version(message, request, send): + return False + + if not self.mcp_session_id: + return True + + request_session_id = self._get_session_id(request) + if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + + return True + + async def _validate_initialization_protocol_version( + self, message: JSONRPCRequest, request: Request, send: Send + ) -> bool: + header_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) + body_protocol_version = str(message.params.get("protocolVersion")) if message.params else None + if ( + header_protocol_version is not None + and body_protocol_version is not None + and header_protocol_version != body_protocol_version + ): + response = self._create_error_response( + f"Bad Request: {MCP_PROTOCOL_VERSION_HEADER} header does not match initialize.params.protocolVersion", + HTTPStatus.BAD_REQUEST, + INVALID_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + return True + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover """Replays events that would have been sent after the specified event ID. diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..9641291979 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1718,6 +1718,43 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv assert response.status_code == 200 +@pytest.mark.parametrize( + ("header_version", "body_version"), + [ + ("2025-03-26", "2025-06-18"), + ("2025-06-18", "2025-03-26"), + ], +) +def test_server_rejects_initialize_protocol_version_mismatch( + basic_server: None, basic_server_url: str, header_version: str, body_version: str +): + """Test initialize rejects conflicting protocol versions in header and body.""" + init_request: dict[str, Any] = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": body_version, + "capabilities": {}, + }, + "id": "init-1", + } + + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_PROTOCOL_VERSION_HEADER: header_version, + }, + json=init_request, + ) + + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text + assert "protocolVersion" in response.text + + def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID From 2cb41c354af86e663de1c31a755f39fe3e2ab9af Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Tue, 19 May 2026 10:16:59 +0800 Subject: [PATCH 2/2] test: cover initialize protocol mismatch response --- tests/shared/test_streamable_http.py | 62 ++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 9641291979..4269178754 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -26,6 +26,7 @@ from starlette.applications import Starlette from starlette.requests import Request from starlette.routing import Mount +from starlette.types import Message from mcp import MCPError, types from mcp.client.session import ClientSession @@ -1755,6 +1756,67 @@ def test_server_rejects_initialize_protocol_version_mismatch( assert "protocolVersion" in response.text +@pytest.mark.anyio +async def test_server_rejects_initialize_protocol_version_mismatch_in_process(): + transport = StreamableHTTPServerTransport("/mcp") + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) + transport._read_stream_writer = write_stream + body = json.dumps( + { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-06-18", + "capabilities": {}, + }, + "id": "init-1", + } + ).encode() + sent: list[Message] = [] + received_body = False + + async def receive() -> Message: + nonlocal received_body + if received_body: + return {"type": "http.disconnect"} + + received_body = True + return {"type": "http.request", "body": body, "more_body": False} + + async def send(message: Message) -> None: + sent.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "method": "POST", + "path": "/mcp", + "raw_path": b"/mcp", + "query_string": b"", + "headers": [ + (b"accept", b"application/json, text/event-stream"), + (b"content-type", b"application/json"), + (MCP_PROTOCOL_VERSION_HEADER.encode(), b"2025-03-26"), + ], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "scheme": "http", + } + + try: + await transport.handle_request(scope, receive, send) + assert await receive() == {"type": "http.disconnect"} + finally: + await write_stream.aclose() + await read_stream.aclose() + + assert any(message["type"] == "http.response.start" and message["status"] == 400 for message in sent) + response_body = b"".join(message.get("body", b"") for message in sent if message["type"] == "http.response.body") + assert MCP_PROTOCOL_VERSION_HEADER.encode() in response_body + assert b"protocolVersion" in response_body + + def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID