From 50b24b3b14fb1bfa2e0d1f28bdb662106aecdbfe Mon Sep 17 00:00:00 2001 From: Pierre De Rop Date: Fri, 15 Sep 2023 08:44:04 +0200 Subject: [PATCH] ReactorNetty HttpClient sometimes can't get Tomcat early error 400 response (#2864) In scenarios where the ReactorNetty HttpClient writes a substantial POST request to a Tomcat server through HttpClient.send(Mono), and HTTP/1.1 plain is used, and Tomcat promptly responds with an early "400 Bad Request" status before reading the full request body bytes, an issue arises. The HttpClient, in such cases, delays reading the response until after the entire request body has been flushed. Consequently, instead of correctly reporting the early "400 Bad Request" status, it might mistakenly trigger a "Connection prematurely closed BEFORE response" error, because at some point the connection is closed by the server while the client is still writing the request body. To address this problem, this patch activates read interest in the channel when it becomes unwritable. This modification significantly improves the situation in most cases, especially considering Tomcat's default configuration, which includes the "maxSwallowSize" property. With this configuration, Tomcat continues reading request body bytes after sending the "400 Bad Request" response (up to 2 MB) before closing the connection, allowing the HttpClient ample time to consume the "400 Bad Request" status and deliver it to the user as expected. But for servers which close the connection immediately, a TCP/RST might be sent to the client, and in this case the patch cannot always work because TCP/RST is not reliable and any unread data from the HttpClient OS may be dropped. Fixes #2825 --- .../netty/channel/ChannelOperations.java | 8 + .../channel/ChannelOperationsHandler.java | 8 + .../reactor/netty/http/HttpOperations.java | 10 ++ .../AbstractHttpClientMetricsHandler.java | 23 ++- .../http/client/HttpClientOperations.java | 41 +++++ .../test/java/reactor/netty/TomcatServer.java | 39 ++++- .../http/client/HttpClientWithTomcatTest.java | 143 +++++++++++++++++- 7 files changed, 268 insertions(+), 4 deletions(-) diff --git a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperations.java b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperations.java index 1b998f2e4c..0ae1c4e4fa 100644 --- a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperations.java +++ b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperations.java @@ -544,6 +544,14 @@ protected String asDebugLogMessage(Object o) { return o.toString(); } + /** + * React on Channel writability change. + * + * @since 1.0.37 + */ + protected void onWritabilityChanged() { + } + @Override public boolean isPersistent() { return connection.isPersistent(); diff --git a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperationsHandler.java b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperationsHandler.java index 5a5f9c4d65..dd9565d890 100755 --- a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperationsHandler.java +++ b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperationsHandler.java @@ -149,6 +149,14 @@ final public void exceptionCaught(ChannelHandlerContext ctx, Throwable err) { } } + @Override + final public void channelWritabilityChanged(ChannelHandlerContext ctx) { + ChannelOperations ops = ChannelOperations.get(ctx.channel()); + if (ops != null) { + ops.onWritabilityChanged(); + } + } + static void safeRelease(Object msg) { if (msg instanceof ReferenceCounted) { ReferenceCounted referenceCounted = (ReferenceCounted) msg; diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/HttpOperations.java b/reactor-netty-http/src/main/java/reactor/netty/http/HttpOperations.java index 087dabe4b9..4f881c840e 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/HttpOperations.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/HttpOperations.java @@ -335,6 +335,16 @@ protected final boolean markSentBody() { return HTTP_STATE.compareAndSet(this, HEADERS_SENT, BODY_SENT); } + /** + * Has Body been sent + * + * @return true if body has been sent + * @since 1.0.37 + */ + protected final boolean hasSentBody() { + return statusAndHeadersSent == BODY_SENT; + } + /** * Mark the headers and body sent * diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/client/AbstractHttpClientMetricsHandler.java b/reactor-netty-http/src/main/java/reactor/netty/http/client/AbstractHttpClientMetricsHandler.java index 1bfb77a454..228bd5c9c6 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/client/AbstractHttpClientMetricsHandler.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/client/AbstractHttpClientMetricsHandler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2021-2023 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,6 +63,10 @@ abstract class AbstractHttpClientMetricsHandler extends ChannelDuplexHandler { final Function uriTagValue; + int lastReadSeq; + + int lastWriteSeq; + protected AbstractHttpClientMetricsHandler(@Nullable Function uriTagValue) { this.uriTagValue = uriTagValue; } @@ -77,6 +81,8 @@ protected AbstractHttpClientMetricsHandler(AbstractHttpClientMetricsHandler copy this.path = copy.path; this.status = copy.status; this.uriTagValue = copy.uriTagValue; + this.lastWriteSeq = copy.lastWriteSeq; + this.lastReadSeq = copy.lastReadSeq; } @Override @@ -90,10 +96,15 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) dataSent += extractProcessedDataFromBuffer(msg); if (msg instanceof LastHttpContent) { + int currentLastWriteSeq = lastWriteSeq; SocketAddress address = ctx.channel().remoteAddress(); promise.addListener(future -> { try { - recordWrite(address); + // Record write, unless channelRead has already done it (because an early full response has been received) + if (currentLastWriteSeq == lastWriteSeq) { + lastWriteSeq = (lastWriteSeq + 1) & 0x7F_FF_FF_FF; + recordWrite(address); + } } catch (RuntimeException e) { if (log.isWarnEnabled()) { @@ -126,6 +137,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { dataReceived += extractProcessedDataFromBuffer(msg); if (msg instanceof LastHttpContent) { + // Detect if we have received an early response before the request has been fully flushed. + // In this case, invoke recordwrite now (because next we will reset all class fields). + lastReadSeq = (lastReadSeq + 1) & 0x7F_FF_FF_FF; + if ((lastReadSeq > lastWriteSeq) || (lastReadSeq == 0 && lastWriteSeq == Integer.MAX_VALUE)) { + lastWriteSeq = (lastWriteSeq + 1) & 0x7F_FF_FF_FF; + recordWrite(ctx.channel().remoteAddress()); + } recordRead(ctx.channel().remoteAddress()); reset(); } @@ -223,5 +241,6 @@ private void reset() { dataSent = 0; dataReceivedTime = 0; dataSentTime = 0; + // don't reset lastWriteSeq and lastReadSeq, which must be incremented for ever } } diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/client/HttpClientOperations.java b/reactor-netty-http/src/main/java/reactor/netty/http/client/HttpClientOperations.java index 78f1316f18..0a72365524 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/client/HttpClientOperations.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/client/HttpClientOperations.java @@ -60,6 +60,7 @@ import io.netty.handler.codec.http.multipart.HttpDataFactory; import io.netty.handler.codec.http.multipart.HttpPostRequestEncoder; import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler; +import io.netty.handler.codec.http2.Http2StreamChannel; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.util.ReferenceCountUtil; @@ -524,6 +525,46 @@ else if (version.equals(HttpVersion.HTTP_1_1)) { throw new IllegalStateException(version.protocolName() + " not supported"); } + /** + * React on channel unwritability event while the http client request is being written. + * + *

When using plain HTTP/1.1 and {@code HttpClient.send(Mono)}, if the socket becomes unwritable while writing, + * we need to request for reads. This is necessary to read any early server response, such as a 400 bad request followed + * by a socket close, while the request is still being written. Else, a "premature close exception before response" may be reported + * to the user, causing confusion about the server's early response. + * + *

There is no need to request for reading in other cases + * (H2/H2C/H1S/WebSocket), because in these cases the read interest has already been requested, or auto-read is enabled + * + *

Important notes: + *

+ * - If the connection is unwritable and {@code send(Flux)} has been used, then {@code hasSentBody()} will + * always return false, because when {@code send(Flux)} is used, {@code hasSentBody()} can only return true + * if the request is fully written (see {@link #onOutboundComplete()} method which invokes {@code markSentBody()} + * and sets the state to BODY_SENT). + * So if channel is unwritable and {@code hasSentBody()} returns true, it means that {@code send(Mono)} has + * been used (see {@link HttpOperations#send(Publisher)} where {@code markSentHeaderAndBody(b)} is setting + * the state to BODY_SENT when the Publisher is a Mono). + * + *

- When the channel is unwritable, a channel read() has already been requested or is in auto-read if: + *

  • Secure mode is used (Netty SslHandler requests read() when flushing).
  • + *
  • HTTP2 is used.
  • + *
  • WebSocket is used.
  • + *
+ * + *

See GH-2825 for more info + */ + @Override + protected void onWritabilityChanged() { + if (!isSecure && + !channel().isWritable() && !channel().config().isAutoRead() && + hasSentBody() && + !(channel() instanceof Http2StreamChannel) && + !isWebsocket()) { + channel().read(); + } + } + @Override protected void afterMarkSentHeaders() { //Noop diff --git a/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java b/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java index 4da3bc9a23..dc6bcc4d9a 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java +++ b/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2019-2023 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import javax.servlet.http.Part; import java.io.File; import java.io.IOException; +import java.io.InputStream; import java.io.PrintWriter; import java.util.Collection; @@ -35,6 +36,8 @@ */ public class TomcatServer { static final String TOMCAT_BASE_DIR = "./build/tomcat"; + public static final String TOO_LARGE = "Request payload too large"; + public static final int PAYLOAD_MAX = 5000000; final Tomcat tomcat; @@ -82,6 +85,7 @@ public void createDefaultContext() { addServlet(ctx, new StatusServlet(), "/status/*"); addServlet(ctx, new MultipartServlet(), "/multipart") .setMultipartConfigElement(new MultipartConfigElement("")); + addServlet(ctx, new PayloadSizeServlet(), "/payload-size"); } public void createContext(HttpServlet servlet, String mapping) { @@ -163,4 +167,37 @@ protected void service(HttpServletRequest req, HttpServletResponse resp) throws } } } + + static final class PayloadSizeServlet extends HttpServlet { + + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws IOException { + InputStream in = req.getInputStream(); + int count = 0; + int n; + + while ((n = in.read()) != -1) { + count += n; + if (count >= PAYLOAD_MAX) { + // By default, Tomcat is configured with maxSwallowSize=2 MB (see https://tomcat.apache.org/tomcat-9.0-doc/config/http.html) + // This means that once the 400 bad request is sent, the client will still be able to continue writing (if it is currently writing) + // up to 2 MB. So, it is very likely that the client will be blocked and it will then be able to consume the 400 bad request and + // close itself the connection. + sendResponse(resp, TOO_LARGE, HttpServletResponse.SC_BAD_REQUEST); + return; + } + } + + sendResponse(resp, String.valueOf(count), HttpServletResponse.SC_OK); + } + + private void sendResponse(HttpServletResponse resp, String message, int status) throws IOException { + resp.setStatus(status); + resp.setHeader("Transfer-Encoding", "chunked"); + resp.setHeader("Content-Type", "text/plain"); + PrintWriter out = resp.getWriter(); + out.print(message); + out.flush(); + } + } } diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java b/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java index a4ff369b9c..bdbfc18e13 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2019-2023 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,10 @@ */ package reactor.netty.http.client; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaders; @@ -26,21 +29,30 @@ import io.netty.handler.codec.http.multipart.HttpData; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.ByteBufFlux; import reactor.netty.TomcatServer; import reactor.netty.resources.ConnectionProvider; +import reactor.test.StepVerifier; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; import java.io.InputStream; import java.lang.reflect.Field; +import java.net.SocketAddress; +import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.time.Duration; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -48,6 +60,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; import static reactor.netty.http.client.HttpClientOperations.SendForm.DEFAULT_FACTORY; @@ -57,6 +71,8 @@ */ class HttpClientWithTomcatTest { private static TomcatServer tomcat; + private static final byte[] PAYLOAD = String.join("", Collections.nCopies(TomcatServer.PAYLOAD_MAX + (1024 * 1024), "X")) + .getBytes(Charset.defaultCharset()); @BeforeAll static void startTomcat() throws Exception { @@ -317,6 +333,48 @@ void contentHeader() { fixed.dispose(); } + static Stream testIssue2825Args() { + Supplier> postMono = () -> Mono.just(Unpooled.wrappedBuffer(PAYLOAD)); + Supplier> postFlux = () -> Flux.just(Unpooled.wrappedBuffer(PAYLOAD)); + + return Stream.of( + Arguments.of(Named.of("postMono", postMono), Named.of("bytes", PAYLOAD.length)), + Arguments.of(Named.of("postFlux", postFlux), Named.of("bytes", PAYLOAD.length)) + ); + } + + @ParameterizedTest + @MethodSource("testIssue2825Args") + void testIssue2825(Supplier> payload, long bytesToSend) { + AtomicReference serverAddress = new AtomicReference<>(); + HttpClient client = HttpClient.create() + .port(getPort()) + .wiretap(false) + .metrics(true, ClientMetricsRecorder::reset) + .doOnConnected(conn -> serverAddress.set(conn.address())); + + StepVerifier.create(client + .headers(hdr -> hdr.set("Content-Type", "text/plain")) + .post() + .uri("/payload-size") + .send(payload.get()) + .response((r, buf) -> buf.aggregate().asString().zipWith(Mono.just(r)))) + .expectNextMatches(tuple -> TomcatServer.TOO_LARGE.equals(tuple.getT1()) + && tuple.getT2().status().equals(HttpResponseStatus.BAD_REQUEST)) + .expectComplete() + .verify(Duration.ofSeconds(30)); + + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeMethod).isEqualTo("POST"); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime).isNotNull(); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime.isZero()).isFalse(); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeUri).isEqualTo("/payload-size"); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeRemoteAddr).isEqualTo(serverAddress.get()); + + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentRemoteAddr).isEqualTo(serverAddress.get()); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentUri).isEqualTo("/payload-size"); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentBytes).isEqualTo(bytesToSend); + } + private int getPort() { return tomcat.port(); } @@ -324,4 +382,87 @@ private int getPort() { private String getURL() { return "http://localhost:" + tomcat.port(); } + + /** + * This Custom metrics recorder checks that the {@link AbstractHttpClientMetricsHandler#recordWrite(SocketAddress)} is properly invoked by + * (see {@link AbstractHttpClientMetricsHandler#channelRead(ChannelHandlerContext, Object)}) when + * an early response is received while the corresponding request it still being written. + */ + static final class ClientMetricsRecorder implements HttpClientMetricsRecorder { + + static final ClientMetricsRecorder INSTANCE = new ClientMetricsRecorder(); + volatile SocketAddress recordDataSentTimeRemoteAddr; + volatile String recordDataSentTimeUri; + volatile String recordDataSentTimeMethod; + volatile Duration recordDataSentTimeTime; + volatile SocketAddress recordDataSentRemoteAddr; + volatile String recordDataSentUri; + volatile long recordDataSentBytes; + + static ClientMetricsRecorder reset() { + INSTANCE.recordDataSentTimeRemoteAddr = null; + INSTANCE.recordDataSentTimeUri = null; + INSTANCE.recordDataSentTimeMethod = null; + INSTANCE.recordDataSentTimeTime = null; + INSTANCE.recordDataSentRemoteAddr = null; + INSTANCE.recordDataSentUri = null; + INSTANCE.recordDataSentBytes = -1; + return INSTANCE; + } + + @Override + public void recordDataReceived(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress) { + } + + @Override + public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) { + this.recordDataSentRemoteAddr = remoteAddress; + this.recordDataSentUri = uri; + this.recordDataSentBytes = bytes; + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress, String uri) { + } + + @Override + public void recordDataReceivedTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { + } + + @Override + public void recordDataSentTime(SocketAddress remoteAddress, String uri, String method, Duration time) { + this.recordDataSentTimeRemoteAddr = remoteAddress; + this.recordDataSentTimeUri = uri; + this.recordDataSentTimeMethod = method; + this.recordDataSentTimeTime = time; + } + + @Override + public void recordResponseTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { + } + } }