diff --git a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperations.java b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperations.java index 0ec625fccf..5c9f7b77ae 100644 --- a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperations.java +++ b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperations.java @@ -551,6 +551,14 @@ protected String asDebugLogMessage(Object o) { return o.toString(); } + /** + * React on Channel writability change. + * + * @since 1.0.37 + */ + protected void onWritabilityChanged() { + } + @Override public boolean isPersistent() { return connection.isPersistent(); diff --git a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperationsHandler.java b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperationsHandler.java index 5a5f9c4d65..dd9565d890 100755 --- a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperationsHandler.java +++ b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelOperationsHandler.java @@ -149,6 +149,14 @@ final public void exceptionCaught(ChannelHandlerContext ctx, Throwable err) { } } + @Override + final public void channelWritabilityChanged(ChannelHandlerContext ctx) { + ChannelOperations ops = ChannelOperations.get(ctx.channel()); + if (ops != null) { + ops.onWritabilityChanged(); + } + } + static void safeRelease(Object msg) { if (msg instanceof ReferenceCounted) { ReferenceCounted referenceCounted = (ReferenceCounted) msg; diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/HttpOperations.java b/reactor-netty-http/src/main/java/reactor/netty/http/HttpOperations.java index 087dabe4b9..4f881c840e 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/HttpOperations.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/HttpOperations.java @@ -335,6 +335,16 @@ protected final boolean markSentBody() { return HTTP_STATE.compareAndSet(this, HEADERS_SENT, BODY_SENT); } + /** + * Has Body been sent + * + * @return true if body has been sent + * @since 1.0.37 + */ + protected final boolean hasSentBody() { + return statusAndHeadersSent == BODY_SENT; + } + /** * Mark the headers and body sent * diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/client/AbstractHttpClientMetricsHandler.java b/reactor-netty-http/src/main/java/reactor/netty/http/client/AbstractHttpClientMetricsHandler.java index 7efe2d7c74..6950fef345 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/client/AbstractHttpClientMetricsHandler.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/client/AbstractHttpClientMetricsHandler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2021-2023 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,6 +64,10 @@ abstract class AbstractHttpClientMetricsHandler extends ChannelDuplexHandler { final Function uriTagValue; + int lastReadSeq; + + int lastWriteSeq; + protected AbstractHttpClientMetricsHandler(@Nullable Function uriTagValue) { this.uriTagValue = uriTagValue; } @@ -78,6 +82,8 @@ protected AbstractHttpClientMetricsHandler(AbstractHttpClientMetricsHandler copy this.path = copy.path; this.status = copy.status; this.uriTagValue = copy.uriTagValue; + this.lastWriteSeq = copy.lastWriteSeq; + this.lastReadSeq = copy.lastReadSeq; } @Override @@ -91,10 +97,15 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) dataSent += extractProcessedDataFromBuffer(msg); if (msg instanceof LastHttpContent) { + int currentLastWriteSeq = lastWriteSeq; SocketAddress address = ctx.channel().remoteAddress(); promise.addListener(future -> { try { - recordWrite(address); + // Record write, unless channelRead has already done it (because an early full response has been received) + if (currentLastWriteSeq == lastWriteSeq) { + lastWriteSeq = (lastWriteSeq + 1) & 0x7F_FF_FF_FF; + recordWrite(address); + } } catch (RuntimeException e) { if (log.isWarnEnabled()) { @@ -128,6 +139,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { dataReceived += extractProcessedDataFromBuffer(msg); if (msg instanceof LastHttpContent) { + // Detect if we have received an early response before the request has been fully flushed. + // In this case, invoke recordwrite now (because next we will reset all class fields). + lastReadSeq = (lastReadSeq + 1) & 0x7F_FF_FF_FF; + if ((lastReadSeq > lastWriteSeq) || (lastReadSeq == 0 && lastWriteSeq == Integer.MAX_VALUE)) { + lastWriteSeq = (lastWriteSeq + 1) & 0x7F_FF_FF_FF; + recordWrite(ctx.channel().remoteAddress()); + } recordRead(ctx.channel()); reset(); } @@ -217,6 +235,7 @@ protected void reset() { dataSent = 0; dataReceivedTime = 0; dataSentTime = 0; + // don't reset lastWriteSeq and lastReadSeq, which must be incremented for ever } protected void startRead(HttpResponse msg) { diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/client/HttpClientOperations.java b/reactor-netty-http/src/main/java/reactor/netty/http/client/HttpClientOperations.java index 78f1316f18..0a72365524 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/client/HttpClientOperations.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/client/HttpClientOperations.java @@ -60,6 +60,7 @@ import io.netty.handler.codec.http.multipart.HttpDataFactory; import io.netty.handler.codec.http.multipart.HttpPostRequestEncoder; import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler; +import io.netty.handler.codec.http2.Http2StreamChannel; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.util.ReferenceCountUtil; @@ -524,6 +525,46 @@ else if (version.equals(HttpVersion.HTTP_1_1)) { throw new IllegalStateException(version.protocolName() + " not supported"); } + /** + * React on channel unwritability event while the http client request is being written. + * + *

When using plain HTTP/1.1 and {@code HttpClient.send(Mono)}, if the socket becomes unwritable while writing, + * we need to request for reads. This is necessary to read any early server response, such as a 400 bad request followed + * by a socket close, while the request is still being written. Else, a "premature close exception before response" may be reported + * to the user, causing confusion about the server's early response. + * + *

There is no need to request for reading in other cases + * (H2/H2C/H1S/WebSocket), because in these cases the read interest has already been requested, or auto-read is enabled + * + *

Important notes: + *

+ * - If the connection is unwritable and {@code send(Flux)} has been used, then {@code hasSentBody()} will + * always return false, because when {@code send(Flux)} is used, {@code hasSentBody()} can only return true + * if the request is fully written (see {@link #onOutboundComplete()} method which invokes {@code markSentBody()} + * and sets the state to BODY_SENT). + * So if channel is unwritable and {@code hasSentBody()} returns true, it means that {@code send(Mono)} has + * been used (see {@link HttpOperations#send(Publisher)} where {@code markSentHeaderAndBody(b)} is setting + * the state to BODY_SENT when the Publisher is a Mono). + * + *

- When the channel is unwritable, a channel read() has already been requested or is in auto-read if: + *

+ * + *

See GH-2825 for more info + */ + @Override + protected void onWritabilityChanged() { + if (!isSecure && + !channel().isWritable() && !channel().config().isAutoRead() && + hasSentBody() && + !(channel() instanceof Http2StreamChannel) && + !isWebsocket()) { + channel().read(); + } + } + @Override protected void afterMarkSentHeaders() { //Noop diff --git a/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java b/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java index 4da3bc9a23..dc6bcc4d9a 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java +++ b/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2019-2023 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import javax.servlet.http.Part; import java.io.File; import java.io.IOException; +import java.io.InputStream; import java.io.PrintWriter; import java.util.Collection; @@ -35,6 +36,8 @@ */ public class TomcatServer { static final String TOMCAT_BASE_DIR = "./build/tomcat"; + public static final String TOO_LARGE = "Request payload too large"; + public static final int PAYLOAD_MAX = 5000000; final Tomcat tomcat; @@ -82,6 +85,7 @@ public void createDefaultContext() { addServlet(ctx, new StatusServlet(), "/status/*"); addServlet(ctx, new MultipartServlet(), "/multipart") .setMultipartConfigElement(new MultipartConfigElement("")); + addServlet(ctx, new PayloadSizeServlet(), "/payload-size"); } public void createContext(HttpServlet servlet, String mapping) { @@ -163,4 +167,37 @@ protected void service(HttpServletRequest req, HttpServletResponse resp) throws } } } + + static final class PayloadSizeServlet extends HttpServlet { + + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws IOException { + InputStream in = req.getInputStream(); + int count = 0; + int n; + + while ((n = in.read()) != -1) { + count += n; + if (count >= PAYLOAD_MAX) { + // By default, Tomcat is configured with maxSwallowSize=2 MB (see https://tomcat.apache.org/tomcat-9.0-doc/config/http.html) + // This means that once the 400 bad request is sent, the client will still be able to continue writing (if it is currently writing) + // up to 2 MB. So, it is very likely that the client will be blocked and it will then be able to consume the 400 bad request and + // close itself the connection. + sendResponse(resp, TOO_LARGE, HttpServletResponse.SC_BAD_REQUEST); + return; + } + } + + sendResponse(resp, String.valueOf(count), HttpServletResponse.SC_OK); + } + + private void sendResponse(HttpServletResponse resp, String message, int status) throws IOException { + resp.setStatus(status); + resp.setHeader("Transfer-Encoding", "chunked"); + resp.setHeader("Content-Type", "text/plain"); + PrintWriter out = resp.getWriter(); + out.print(message); + out.flush(); + } + } } diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java b/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java index a4ff369b9c..bdbfc18e13 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2019-2023 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,10 @@ */ package reactor.netty.http.client; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaders; @@ -26,21 +29,30 @@ import io.netty.handler.codec.http.multipart.HttpData; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.ByteBufFlux; import reactor.netty.TomcatServer; import reactor.netty.resources.ConnectionProvider; +import reactor.test.StepVerifier; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; import java.io.InputStream; import java.lang.reflect.Field; +import java.net.SocketAddress; +import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.time.Duration; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -48,6 +60,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; import static reactor.netty.http.client.HttpClientOperations.SendForm.DEFAULT_FACTORY; @@ -57,6 +71,8 @@ */ class HttpClientWithTomcatTest { private static TomcatServer tomcat; + private static final byte[] PAYLOAD = String.join("", Collections.nCopies(TomcatServer.PAYLOAD_MAX + (1024 * 1024), "X")) + .getBytes(Charset.defaultCharset()); @BeforeAll static void startTomcat() throws Exception { @@ -317,6 +333,48 @@ void contentHeader() { fixed.dispose(); } + static Stream testIssue2825Args() { + Supplier> postMono = () -> Mono.just(Unpooled.wrappedBuffer(PAYLOAD)); + Supplier> postFlux = () -> Flux.just(Unpooled.wrappedBuffer(PAYLOAD)); + + return Stream.of( + Arguments.of(Named.of("postMono", postMono), Named.of("bytes", PAYLOAD.length)), + Arguments.of(Named.of("postFlux", postFlux), Named.of("bytes", PAYLOAD.length)) + ); + } + + @ParameterizedTest + @MethodSource("testIssue2825Args") + void testIssue2825(Supplier> payload, long bytesToSend) { + AtomicReference serverAddress = new AtomicReference<>(); + HttpClient client = HttpClient.create() + .port(getPort()) + .wiretap(false) + .metrics(true, ClientMetricsRecorder::reset) + .doOnConnected(conn -> serverAddress.set(conn.address())); + + StepVerifier.create(client + .headers(hdr -> hdr.set("Content-Type", "text/plain")) + .post() + .uri("/payload-size") + .send(payload.get()) + .response((r, buf) -> buf.aggregate().asString().zipWith(Mono.just(r)))) + .expectNextMatches(tuple -> TomcatServer.TOO_LARGE.equals(tuple.getT1()) + && tuple.getT2().status().equals(HttpResponseStatus.BAD_REQUEST)) + .expectComplete() + .verify(Duration.ofSeconds(30)); + + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeMethod).isEqualTo("POST"); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime).isNotNull(); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime.isZero()).isFalse(); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeUri).isEqualTo("/payload-size"); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeRemoteAddr).isEqualTo(serverAddress.get()); + + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentRemoteAddr).isEqualTo(serverAddress.get()); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentUri).isEqualTo("/payload-size"); + assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentBytes).isEqualTo(bytesToSend); + } + private int getPort() { return tomcat.port(); } @@ -324,4 +382,87 @@ private int getPort() { private String getURL() { return "http://localhost:" + tomcat.port(); } + + /** + * This Custom metrics recorder checks that the {@link AbstractHttpClientMetricsHandler#recordWrite(SocketAddress)} is properly invoked by + * (see {@link AbstractHttpClientMetricsHandler#channelRead(ChannelHandlerContext, Object)}) when + * an early response is received while the corresponding request it still being written. + */ + static final class ClientMetricsRecorder implements HttpClientMetricsRecorder { + + static final ClientMetricsRecorder INSTANCE = new ClientMetricsRecorder(); + volatile SocketAddress recordDataSentTimeRemoteAddr; + volatile String recordDataSentTimeUri; + volatile String recordDataSentTimeMethod; + volatile Duration recordDataSentTimeTime; + volatile SocketAddress recordDataSentRemoteAddr; + volatile String recordDataSentUri; + volatile long recordDataSentBytes; + + static ClientMetricsRecorder reset() { + INSTANCE.recordDataSentTimeRemoteAddr = null; + INSTANCE.recordDataSentTimeUri = null; + INSTANCE.recordDataSentTimeMethod = null; + INSTANCE.recordDataSentTimeTime = null; + INSTANCE.recordDataSentRemoteAddr = null; + INSTANCE.recordDataSentUri = null; + INSTANCE.recordDataSentBytes = -1; + return INSTANCE; + } + + @Override + public void recordDataReceived(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress) { + } + + @Override + public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) { + this.recordDataSentRemoteAddr = remoteAddress; + this.recordDataSentUri = uri; + this.recordDataSentBytes = bytes; + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress, String uri) { + } + + @Override + public void recordDataReceivedTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { + } + + @Override + public void recordDataSentTime(SocketAddress remoteAddress, String uri, String method, Duration time) { + this.recordDataSentTimeRemoteAddr = remoteAddress; + this.recordDataSentTimeUri = uri; + this.recordDataSentTimeMethod = method; + this.recordDataSentTimeTime = time; + } + + @Override + public void recordResponseTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { + } + } }