diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 0a8dff363..7f22048a7 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -429,6 +429,9 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); + // https://github.com/modelcontextprotocol/java-sdk/issues/889 + Object requestId = (sentMessage instanceof McpSchema.JSONRPCRequest req) ? req.id() : null; + var uri = Utils.resolveUri(this.baseUri, this.endpoint); String jsonBody = this.toString(sentMessage); @@ -586,12 +589,19 @@ else if (statusCode == BAD_REQUEST) { }) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) .onErrorMap(CompletionException.class, t -> t.getCause()) - .onErrorComplete(t -> { - // handle the error first + .onErrorResume(t -> { this.handleException(t); - // inform the caller of sendMessage deliveredSink.error(t); - return true; + if (requestId != null) { + // Emit synthetic error so pending response is resolved + logger.warn("Body-level error for request {}, emitting synthetic error response", requestId, t); + McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, requestId, null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + "Transport error during response streaming: " + t.getMessage(), null)); + return this.handler.get().apply(Mono.just(errorResponse)); + } + return Flux.empty(); }) .doFinally(s -> { logger.debug("SendMessage finally: {}", s); diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportBodyErrorTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportBodyErrorTest.java new file mode 100644 index 000000000..4d13204bb --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportBodyErrorTest.java @@ -0,0 +1,194 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.test.StepVerifier; + +/** + * Tests for body-level error handling in {@link HttpClientStreamableHttpTransport}. + * + * @author James Kennedy + * @see #889 + */ +@Timeout(15) +public class HttpClientStreamableHttpTransportBodyErrorTest { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private McpClientTransport transport; + + private final AtomicBoolean returnMalformedSse = new AtomicBoolean(false); + + @BeforeEach + void startServer() throws IOException { + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + server.createContext("/mcp", exchange -> { + String method = exchange.getRequestMethod(); + + if ("DELETE".equals(method)) { + exchange.sendResponseHeaders(200, 0); + exchange.close(); + return; + } + + if ("GET".equals(method)) { + exchange.sendResponseHeaders(405, 0); + exchange.close(); + return; + } + + if (returnMalformedSse.get()) { + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, 0); + OutputStream os = exchange.getResponseBody(); + os.write("event: message\ndata: {not valid json\n\n".getBytes(StandardCharsets.UTF_8)); + os.flush(); + exchange.close(); + return; + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"init-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes(StandardCharsets.UTF_8)); + exchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = HttpClientStreamableHttpTransport.builder(HOST).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + @Test + void bodyErrorOnSseStreamPropagatesError() { + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + returnMalformedSse.set(true); + + var request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "req-123", + new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0"))); + + StepVerifier.create(transport.sendMessage(request)).expectError(McpTransportException.class).verify(); + } + + @Test + void bodyErrorOnJsonResponseEmitsSyntheticErrorResponse() throws InterruptedException { + var handlerMessages = new CopyOnWriteArrayList(); + CountDownLatch errorResponseLatch = new CountDownLatch(1); + + StepVerifier.create(transport.connect(msg -> msg.doOnNext(m -> { + handlerMessages.add(m); + if (m instanceof McpSchema.JSONRPCResponse resp && resp.error() != null) { + errorResponseLatch.countDown(); + } + }))).verifyComplete(); + + server.removeContext("/mcp"); + server.createContext("/mcp", exchange -> { + String method = exchange.getRequestMethod(); + + if ("DELETE".equals(method)) { + exchange.sendResponseHeaders(200, 0); + exchange.close(); + return; + } + + if ("GET".equals(method)) { + exchange.sendResponseHeaders(405, 0); + exchange.close(); + return; + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + byte[] malformed = "{not valid json".getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, malformed.length); + exchange.getResponseBody().write(malformed); + exchange.close(); + }); + + var request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "req-456", + new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0"))); + + StepVerifier.create(transport.sendMessage(request)).verifyComplete(); + + assertThat(errorResponseLatch.await(5, TimeUnit.SECONDS)) + .as("Handler should receive synthetic error response within 5 seconds") + .isTrue(); + + var errorResponses = handlerMessages.stream() + .filter(m -> m instanceof McpSchema.JSONRPCResponse resp && resp.error() != null) + .map(m -> (McpSchema.JSONRPCResponse) m) + .toList(); + + assertThat(errorResponses).hasSize(1); + McpSchema.JSONRPCResponse errorResponse = errorResponses.get(0); + assertThat(errorResponse.id()).isEqualTo("req-456"); + assertThat(errorResponse.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(errorResponse.error().message()).contains("Transport error"); + } + + @Test + void bodyErrorOnNotificationDoesNotEmitSyntheticResponse() throws InterruptedException { + var handlerMessages = new CopyOnWriteArrayList(); + + StepVerifier.create(transport.connect(msg -> msg.doOnNext(handlerMessages::add))).verifyComplete(); + + returnMalformedSse.set(true); + + var notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, "notifications/cancelled", + null); + + StepVerifier.create(transport.sendMessage(notification)).expectError(McpTransportException.class).verify(); + + Thread.sleep(500); + + var errorResponses = handlerMessages.stream() + .filter(m -> m instanceof McpSchema.JSONRPCResponse resp && resp.error() != null) + .toList(); + + assertThat(errorResponses).isEmpty(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 0b5ce55cd..4187a157a 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -291,6 +291,9 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); + // https://github.com/modelcontextprotocol/java-sdk/issues/889 + Object requestId = (message instanceof McpSchema.JSONRPCRequest req) ? req.id() : null; + Disposable connection = Flux.deferContextual(ctx -> webClient.post() .uri(this.endpoint) .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM) @@ -356,23 +359,25 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { } return this.extractError(response, sessionRepresentation); } - })) - .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) - .onErrorComplete(t -> { - // handle the error first + })).flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))).onErrorResume(t -> { this.handleException(t); - // inform the caller of sendMessage sink.error(t); - return true; - }) - .doFinally(s -> { + if (requestId != null) { + // Emit synthetic error so pending response is resolved + logger.warn("Body-level error for request {}, emitting synthetic error response", requestId, t); + McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, requestId, null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + "Transport error during response streaming: " + t.getMessage(), null)); + return this.handler.get().apply(Mono.just(errorResponse)); + } + return Flux.empty(); + }).doFinally(s -> { Disposable ref = disposableRef.getAndSet(null); if (ref != null) { transportSession.removeConnection(ref); } - }) - .contextWrite(sink.contextView()) - .subscribe(); + }).contextWrite(sink.contextView()).subscribe(); disposableRef.set(connection); transportSession.addConnection(connection); }); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportBodyErrorTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportBodyErrorTest.java new file mode 100644 index 000000000..9a87e45cd --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportBodyErrorTest.java @@ -0,0 +1,206 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.test.StepVerifier; + +/** + * Tests for body-level error handling in {@link WebClientStreamableHttpTransport}. + * + * @author James Kennedy + * @see #889 + */ +@Timeout(15) +public class WebClientStreamableHttpTransportBodyErrorTest { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private McpClientTransport transport; + + private final AtomicBoolean returnMalformedSse = new AtomicBoolean(false); + + @BeforeEach + void startServer() throws IOException { + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + server.createContext("/mcp", exchange -> { + String method = exchange.getRequestMethod(); + + if ("GET".equals(method)) { + exchange.sendResponseHeaders(405, 0); + exchange.close(); + return; + } + + if (returnMalformedSse.get()) { + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, 0); + OutputStream os = exchange.getResponseBody(); + os.write("event: message\ndata: {not valid json\n\n".getBytes(StandardCharsets.UTF_8)); + os.flush(); + exchange.close(); + return; + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"init-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes(StandardCharsets.UTF_8)); + exchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + @Test + void bodyErrorOnSseStreamEmitsSyntheticErrorResponse() throws InterruptedException { + var handlerMessages = new CopyOnWriteArrayList(); + CountDownLatch errorResponseLatch = new CountDownLatch(1); + + StepVerifier.create(transport.connect(msg -> msg.doOnNext(m -> { + handlerMessages.add(m); + if (m instanceof McpSchema.JSONRPCResponse resp && resp.error() != null) { + errorResponseLatch.countDown(); + } + }))).verifyComplete(); + + returnMalformedSse.set(true); + + var request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "req-123", + new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0"))); + + StepVerifier.create(transport.sendMessage(request)).verifyComplete(); + + assertThat(errorResponseLatch.await(5, TimeUnit.SECONDS)) + .as("Handler should receive synthetic error response within 5 seconds") + .isTrue(); + + var errorResponses = handlerMessages.stream() + .filter(m -> m instanceof McpSchema.JSONRPCResponse resp && resp.error() != null) + .map(m -> (McpSchema.JSONRPCResponse) m) + .toList(); + + assertThat(errorResponses).hasSize(1); + McpSchema.JSONRPCResponse errorResponse = errorResponses.get(0); + assertThat(errorResponse.id()).isEqualTo("req-123"); + assertThat(errorResponse.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(errorResponse.error().message()).contains("Transport error"); + } + + @Test + void bodyErrorOnJsonResponseEmitsSyntheticErrorResponse() throws InterruptedException { + var handlerMessages = new CopyOnWriteArrayList(); + CountDownLatch errorResponseLatch = new CountDownLatch(1); + + StepVerifier.create(transport.connect(msg -> msg.doOnNext(m -> { + handlerMessages.add(m); + if (m instanceof McpSchema.JSONRPCResponse resp && resp.error() != null) { + errorResponseLatch.countDown(); + } + }))).verifyComplete(); + + server.removeContext("/mcp"); + server.createContext("/mcp", exchange -> { + String method = exchange.getRequestMethod(); + + if ("GET".equals(method)) { + exchange.sendResponseHeaders(405, 0); + exchange.close(); + return; + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + byte[] malformed = "{not valid json".getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, malformed.length); + exchange.getResponseBody().write(malformed); + exchange.close(); + }); + + var request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "req-456", + new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0"))); + + StepVerifier.create(transport.sendMessage(request)).verifyComplete(); + + assertThat(errorResponseLatch.await(5, TimeUnit.SECONDS)) + .as("Handler should receive synthetic error response within 5 seconds") + .isTrue(); + + var errorResponses = handlerMessages.stream() + .filter(m -> m instanceof McpSchema.JSONRPCResponse resp && resp.error() != null) + .map(m -> (McpSchema.JSONRPCResponse) m) + .toList(); + + assertThat(errorResponses).hasSize(1); + McpSchema.JSONRPCResponse errorResponse = errorResponses.get(0); + assertThat(errorResponse.id()).isEqualTo("req-456"); + assertThat(errorResponse.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(errorResponse.error().message()).contains("Transport error"); + } + + @Test + void bodyErrorOnNotificationDoesNotEmitSyntheticResponse() throws InterruptedException { + var handlerMessages = new CopyOnWriteArrayList(); + + StepVerifier.create(transport.connect(msg -> msg.doOnNext(handlerMessages::add))).verifyComplete(); + + returnMalformedSse.set(true); + + var notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, "notifications/cancelled", + null); + + StepVerifier.create(transport.sendMessage(notification)).verifyComplete(); + + Thread.sleep(500); + + var errorResponses = handlerMessages.stream() + .filter(m -> m instanceof McpSchema.JSONRPCResponse resp && resp.error() != null) + .toList(); + + assertThat(errorResponses).isEmpty(); + } + +}