diff --git a/webclient/api/src/main/java/io/helidon/webclient/api/HttpClient.java b/webclient/api/src/main/java/io/helidon/webclient/api/HttpClient.java index a08aaa2c6fd..6024757a18a 100644 --- a/webclient/api/src/main/java/io/helidon/webclient/api/HttpClient.java +++ b/webclient/api/src/main/java/io/helidon/webclient/api/HttpClient.java @@ -23,7 +23,7 @@ * * @param type of the client request */ -public interface HttpClient> { +public interface HttpClient> extends ReleasableResource { /** * Create a request for a method. * @@ -32,6 +32,13 @@ public interface HttpClient> { */ REQ method(Method method); + /** + * Gracefully close all opened client specific connections. + */ + default void closeResource() { + // Do nothing by default + } + /** * Shortcut for get method with a path. * diff --git a/webclient/api/src/main/java/io/helidon/webclient/api/LoomClient.java b/webclient/api/src/main/java/io/helidon/webclient/api/LoomClient.java index 86a6ffd3e29..8039bb12c46 100644 --- a/webclient/api/src/main/java/io/helidon/webclient/api/LoomClient.java +++ b/webclient/api/src/main/java/io/helidon/webclient/api/LoomClient.java @@ -151,6 +151,13 @@ public HttpClientRequest method(Method method) { tcpProtocolIds); } + @Override + public void closeResource() { + for (ProtocolSpi o : List.copyOf(clientSpiByProtocol.values())) { + o.spi().releaseResource(); + } + } + @Override public T client(Protocol protocol, C protocolConfig) { return protocol.provider().protocol(this, protocolConfig); diff --git a/webclient/api/src/main/java/io/helidon/webclient/spi/ClientConnectionCache.java b/webclient/api/src/main/java/io/helidon/webclient/spi/ClientConnectionCache.java new file mode 100644 index 00000000000..1f4ab330bad --- /dev/null +++ b/webclient/api/src/main/java/io/helidon/webclient/spi/ClientConnectionCache.java @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.webclient.spi; + +import io.helidon.webclient.api.ReleasableResource; + +import static java.lang.System.Logger.Level; + +/** + * Client connection cache with release shutdown hook to provide graceful shutdown. + */ +public abstract class ClientConnectionCache implements ReleasableResource { + + private static final System.Logger LOGGER = System.getLogger(ClientConnectionCache.class.getName()); + + protected ClientConnectionCache(boolean shared) { + if (shared) { + Runtime.getRuntime().addShutdownHook(new Thread(this::onShutdown)); + } + } + + private void onShutdown() { + if (LOGGER.isLoggable(Level.DEBUG)) { + LOGGER.log(Level.DEBUG, "Gracefully closing connections in client connection cache."); + } + this.releaseResource(); + } +} diff --git a/webclient/api/src/main/java/io/helidon/webclient/spi/HttpClientSpi.java b/webclient/api/src/main/java/io/helidon/webclient/spi/HttpClientSpi.java index 3956badc3d8..a5e3769700c 100644 --- a/webclient/api/src/main/java/io/helidon/webclient/spi/HttpClientSpi.java +++ b/webclient/api/src/main/java/io/helidon/webclient/spi/HttpClientSpi.java @@ -19,11 +19,12 @@ import io.helidon.webclient.api.ClientRequest; import io.helidon.webclient.api.ClientUri; import io.helidon.webclient.api.FullClientRequest; +import io.helidon.webclient.api.ReleasableResource; /** * Integration for HTTP versions to provide a single API. */ -public interface HttpClientSpi { +public interface HttpClientSpi extends ReleasableResource { /** * Return whether this HTTP version can handle the provided request. *

diff --git a/webclient/http1/src/main/java/io/helidon/webclient/http1/Http1ClientImpl.java b/webclient/http1/src/main/java/io/helidon/webclient/http1/Http1ClientImpl.java index 101b3980945..f06c4300837 100644 --- a/webclient/http1/src/main/java/io/helidon/webclient/http1/Http1ClientImpl.java +++ b/webclient/http1/src/main/java/io/helidon/webclient/http1/Http1ClientImpl.java @@ -28,6 +28,7 @@ class Http1ClientImpl implements Http1Client, HttpClientSpi { private final Http1ClientConfig clientConfig; private final Http1ClientProtocolConfig protocolConfig; private final Http1ConnectionCache connectionCache; + private final Http1ConnectionCache clientCache; Http1ClientImpl(WebClient webClient, Http1ClientConfig clientConfig) { this.webClient = webClient; @@ -35,8 +36,10 @@ class Http1ClientImpl implements Http1Client, HttpClientSpi { this.protocolConfig = clientConfig.protocolConfig(); if (clientConfig.shareConnectionCache()) { this.connectionCache = Http1ConnectionCache.shared(); + this.clientCache = null; } else { this.connectionCache = Http1ConnectionCache.create(); + this.clientCache = connectionCache; } } @@ -86,6 +89,13 @@ public ClientRequest clientRequest(FullClientRequest clientRequest, Client .fragment(clientUri.fragment()); } + @Override + public void closeResource() { + if (clientCache != null) { + this.clientCache.closeResource(); + } + } + WebClient webClient() { return webClient; } diff --git a/webclient/http1/src/main/java/io/helidon/webclient/http1/Http1ConnectionCache.java b/webclient/http1/src/main/java/io/helidon/webclient/http1/Http1ConnectionCache.java index 74eb18af7e7..e596786d0bb 100644 --- a/webclient/http1/src/main/java/io/helidon/webclient/http1/Http1ConnectionCache.java +++ b/webclient/http1/src/main/java/io/helidon/webclient/http1/Http1ConnectionCache.java @@ -17,11 +17,13 @@ package io.helidon.webclient.http1; import java.time.Duration; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import io.helidon.common.tls.Tls; import io.helidon.http.ClientRequestHeaders; @@ -33,27 +35,33 @@ import io.helidon.webclient.api.Proxy; import io.helidon.webclient.api.TcpClientConnection; import io.helidon.webclient.api.WebClient; +import io.helidon.webclient.spi.ClientConnectionCache; import static java.lang.System.Logger.Level.DEBUG; /** * Cache of HTTP/1.1 connections for keep alive. */ -class Http1ConnectionCache { +class Http1ConnectionCache extends ClientConnectionCache { private static final System.Logger LOGGER = System.getLogger(Http1ConnectionCache.class.getName()); private static final Tls NO_TLS = Tls.builder().enabled(false).build(); private static final String HTTPS = "https"; - private static final Http1ConnectionCache SHARED = create(); + private static final Http1ConnectionCache SHARED = new Http1ConnectionCache(true); private static final List ALPN_ID = List.of(Http1Client.PROTOCOL_ID); private static final Duration QUEUE_TIMEOUT = Duration.ofMillis(10); private final Map> cache = new ConcurrentHashMap<>(); + private final AtomicBoolean closed = new AtomicBoolean(); + + protected Http1ConnectionCache(boolean shared) { + super(shared); + } static Http1ConnectionCache shared() { return SHARED; } static Http1ConnectionCache create() { - return new Http1ConnectionCache(); + return new Http1ConnectionCache(false); } ClientConnection connection(Http1ClientImpl http1Client, @@ -71,6 +79,16 @@ ClientConnection connection(Http1ClientImpl http1Client, } } + @Override + public void closeResource() { + if (closed.getAndSet(true)) { + return; + } + cache.values().stream() + .flatMap(Collection::stream) + .forEach(TcpClientConnection::closeResource); + } + private boolean handleKeepAlive(boolean defaultKeepAlive, WritableHeaders headers) { if (headers.contains(HeaderValues.CONNECTION_CLOSE)) { return false; @@ -90,6 +108,11 @@ private ClientConnection keepAliveConnection(Http1ClientImpl http1Client, Tls tls, ClientUri uri, Proxy proxy) { + + if (closed.get()) { + throw new IllegalStateException("Connection cache is closed"); + } + Http1ClientConfig clientConfig = http1Client.clientConfig(); ConnectionKey connectionKey = new ConnectionKey(uri.scheme(), diff --git a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientConnection.java b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientConnection.java index dc800d63461..71a2ff21139 100644 --- a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientConnection.java +++ b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientConnection.java @@ -25,6 +25,7 @@ import java.util.concurrent.Future; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; @@ -85,8 +86,7 @@ class Http2ClientConnection { private Http2Settings serverSettings = Http2Settings.builder() .build(); private Future handleTask; - - private volatile boolean closed = false; + private final AtomicReference state = new AtomicReference<>(State.OPEN); Http2ClientConnection(Http2ClientImpl http2Client, ClientConnection connection) { this.protocolConfig = http2Client.protocolConfig(); @@ -177,7 +177,7 @@ Http2ClientStream tryStream(Http2StreamConfig config) { } boolean closed() { - return closed || (protocolConfig.ping() && !ping()); + return state.get().closed() || (protocolConfig.ping() && !ping()); } boolean ping() { @@ -203,13 +203,15 @@ void updateLastStreamId(int lastStreamId) { } void close() { - closed = true; - try { - handleTask.cancel(true); - ctx.log(LOGGER, TRACE, "Closing connection"); - connection.closeResource(); - } catch (Throwable e) { - ctx.log(LOGGER, TRACE, "Failed to close HTTP/2 connection.", e); + this.goAway(0, Http2ErrorCode.NO_ERROR, "Closing connection"); + if (state.getAndSet(State.CLOSED) != State.CLOSED) { + try { + handleTask.cancel(true); + ctx.log(LOGGER, TRACE, "Closing connection"); + connection.closeResource(); + } catch (Throwable e) { + ctx.log(LOGGER, TRACE, "Failed to close HTTP/2 connection.", e); + } } } @@ -268,14 +270,14 @@ private void start(Http2ClientProtocolConfig protocolConfig, try { while (!Thread.interrupted()) { if (!handle()) { - closed = true; + this.close(); ctx.log(LOGGER, TRACE, "Connection closed"); return; } } ctx.log(LOGGER, TRACE, "Client listener interrupted"); } catch (Throwable t) { - closed = true; + this.close(); ctx.log(LOGGER, DEBUG, "Failed to handle HTTP/2 client connection", t); } }); @@ -457,8 +459,26 @@ private void ackSettings() { } private void goAway(int streamId, Http2ErrorCode errorCode, String msg) { - Http2Settings http2Settings = Http2Settings.create(); - Http2GoAway frame = new Http2GoAway(streamId, errorCode, msg); - writer.write(frame.toFrameData(http2Settings, 0, Http2Flag.NoFlags.create())); + if (State.OPEN == state.getAndSet(State.GO_AWAY)) { + Http2Settings http2Settings = Http2Settings.create(); + Http2GoAway frame = new Http2GoAway(streamId, errorCode, msg); + writer.write(frame.toFrameData(http2Settings, 0, Http2Flag.NoFlags.create())); + } + } + + private enum State { + CLOSED(true), + GO_AWAY(true), + OPEN(false); + + private final boolean closed; + + State(boolean closed){ + this.closed = closed; + } + + boolean closed() { + return closed; + } } } diff --git a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientConnectionHandler.java b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientConnectionHandler.java index 5f09c4d390c..539b1a378f6 100644 --- a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientConnectionHandler.java +++ b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientConnectionHandler.java @@ -77,7 +77,10 @@ void close() { // this is to prevent concurrent modification (connections remove themselves from the map) Set toClose = new HashSet<>(allConnections.keySet()); toClose.forEach(Http2ClientConnection::close); - this.activeConnection.set(null); + Http2ClientConnection active = this.activeConnection.getAndSet(null); + if (active != null) { + active.close(); + } this.allConnections.clear(); } diff --git a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientImpl.java b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientImpl.java index 5f9c5222506..aa6a42d05bc 100644 --- a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientImpl.java +++ b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientImpl.java @@ -30,6 +30,7 @@ class Http2ClientImpl implements Http2Client, HttpClientSpi { private final Http2ClientConfig clientConfig; private final Http2ClientProtocolConfig protocolConfig; private final Http2ConnectionCache connectionCache; + private final Http2ConnectionCache clientCache; Http2ClientImpl(WebClient webClient, Http2ClientConfig clientConfig) { this.webClient = webClient; @@ -37,8 +38,10 @@ class Http2ClientImpl implements Http2Client, HttpClientSpi { this.protocolConfig = clientConfig.protocolConfig(); if (clientConfig.shareConnectionCache()) { this.connectionCache = Http2ConnectionCache.shared(); + this.clientCache = null; } else { this.connectionCache = Http2ConnectionCache.create(); + this.clientCache = connectionCache; } } @@ -94,6 +97,13 @@ public ClientRequest clientRequest(FullClientRequest clientRequest, Client .fragment(clientUri.fragment()); } + @Override + public void closeResource() { + if (clientCache != null) { + this.clientCache.closeResource(); + } + } + WebClient webClient() { return webClient; } diff --git a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ConnectionCache.java b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ConnectionCache.java index 487e8c2ba73..96decead36f 100644 --- a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ConnectionCache.java +++ b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ConnectionCache.java @@ -16,8 +16,10 @@ package io.helidon.webclient.http2; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import io.helidon.common.configurable.LruCache; @@ -25,21 +27,34 @@ import io.helidon.webclient.api.ConnectionKey; import io.helidon.webclient.http1.Http1ClientRequest; import io.helidon.webclient.http1.Http1ClientResponse; +import io.helidon.webclient.spi.ClientConnectionCache; -final class Http2ConnectionCache { - //todo Gracefully close connections in channel cache - private static final Http2ConnectionCache SHARED = create(); +final class Http2ConnectionCache extends ClientConnectionCache { + private static final Http2ConnectionCache SHARED = new Http2ConnectionCache(true); private final LruCache http2Supported = LruCache.builder() .capacity(1000) .build(); private final Map cache = new ConcurrentHashMap<>(); + private final AtomicBoolean closed = new AtomicBoolean(); + + private Http2ConnectionCache(boolean shared) { + super(shared); + } static Http2ConnectionCache shared() { return SHARED; } static Http2ConnectionCache create() { - return new Http2ConnectionCache(); + return new Http2ConnectionCache(false); + } + + @Override + public void closeResource() { + if (!closed.getAndSet(true)) { + List.copyOf(cache.keySet()) + .forEach(this::closeAndRemove); + } } boolean supports(ConnectionKey ck) { @@ -47,11 +62,9 @@ boolean supports(ConnectionKey ck) { } void remove(ConnectionKey connectionKey) { - Http2ClientConnectionHandler handler = cache.remove(connectionKey); - if (handler != null) { - handler.close(); + if (!closed.get()) { + closeAndRemove(connectionKey); } - http2Supported.remove(connectionKey); } Http2ConnectionAttemptResult newStream(Http2ClientImpl http2Client, @@ -60,6 +73,10 @@ Http2ConnectionAttemptResult newStream(Http2ClientImpl http2Client, ClientUri initialUri, Function http1EntityHandler) { + if (closed.get()) { + throw new IllegalStateException("Connection cache is closed"); + } + // this statement locks all threads - must not do anything complicated (just create a new instance) Http2ConnectionAttemptResult result = cache.computeIfAbsent(connectionKey, Http2ClientConnectionHandler::new) @@ -73,4 +90,12 @@ Http2ConnectionAttemptResult newStream(Http2ClientImpl http2Client, } return result; } + + private void closeAndRemove(ConnectionKey connectionKey){ + Http2ClientConnectionHandler handler = cache.remove(connectionKey); + if (handler != null) { + handler.close(); + } + http2Supported.remove(connectionKey); + } } diff --git a/webclient/tests/http1/src/test/java/io/helidon/webclient/tests/SharedCacheTest.java b/webclient/tests/http1/src/test/java/io/helidon/webclient/tests/SharedCacheTest.java index f6a87d85a3f..2f222d30e75 100644 --- a/webclient/tests/http1/src/test/java/io/helidon/webclient/tests/SharedCacheTest.java +++ b/webclient/tests/http1/src/test/java/io/helidon/webclient/tests/SharedCacheTest.java @@ -29,6 +29,7 @@ import static io.helidon.http.Method.POST; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; class SharedCacheTest { @Test @@ -99,6 +100,55 @@ void cacheHttp1NoRestart() { assertThat(res.status(), is(Status.OK_200)); } + // with global connection cache is noop + webClient.closeResource(); + + Integer secondReqClientPort; + try (var res = webClient.post().submit("WHATEVER")) { + secondReqClientPort = res.headers().get(clientPortHeader).get(Integer.TYPE); + assertThat(res.status(), is(Status.OK_200)); + } + + assertThat("In case of cached connection client port must be the same.", + secondReqClientPort, + is(firstReqClientPort)); + } finally { + if (webServer != null) { + webServer.stop(); + } + } + } + + @Test + void clientCache() { + HeaderName clientPortHeader = HeaderNames.create("client-port"); + WebServer webServer = null; + try { + var routing = HttpRouting.builder() + .route(Http1Route.route(POST, "/", (req, res) -> { + res.header(clientPortHeader, String.valueOf(req.remotePeer().port())); + res.send(); + })); + + webServer = WebServer.builder() + .routing(routing) + .build() + .start(); + + int port = webServer.port(); + + WebClient webClient = WebClient.builder() + .shareConnectionCache(false) + .keepAlive(true) + .baseUri("http://localhost:" + port) + .build(); + + Integer firstReqClientPort; + try (var res = webClient.post().submit("WHATEVER")) { + firstReqClientPort = res.headers().get(clientPortHeader).get(Integer.TYPE); + assertThat(res.status(), is(Status.OK_200)); + } + Integer secondReqClientPort; try (var res = webClient.post().submit("WHATEVER")) { secondReqClientPort = res.headers().get(clientPortHeader).get(Integer.TYPE); @@ -114,4 +164,50 @@ void cacheHttp1NoRestart() { } } } + + @Test + void clientCacheClosed() { + HeaderName clientPortHeader = HeaderNames.create("client-port"); + WebServer webServer = null; + try { + var routing = HttpRouting.builder() + .route(Http1Route.route(POST, "/", (req, res) -> { + res.header(clientPortHeader, String.valueOf(req.remotePeer().port())); + res.send(); + })); + + webServer = WebServer.builder() + .routing(routing) + .build() + .start(); + + int port = webServer.port(); + + WebClient webClient = WebClient.builder() + .shareConnectionCache(false) + .keepAlive(true) + .baseUri("http://localhost:" + port) + .build(); + + try (var res = webClient.post().submit("WHATEVER")) { + res.headers().get(clientPortHeader).get(Integer.TYPE); + assertThat(res.status(), is(Status.OK_200)); + } + + webClient.closeResource(); + + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> { + try (var res = webClient.post().submit("WHATEVER")) { + res.headers().get(clientPortHeader).get(Integer.TYPE); + } + }); + assertThat(e.getMessage(), is("Connection cache is closed")); + + } finally { + if (webServer != null) { + webServer.stop(); + } + } + } } diff --git a/webclient/tests/http2/pom.xml b/webclient/tests/http2/pom.xml index 40b7903b673..7e2549700bb 100644 --- a/webclient/tests/http2/pom.xml +++ b/webclient/tests/http2/pom.xml @@ -52,6 +52,11 @@ hamcrest-all test + + io.helidon.logging + helidon-logging-jul + test + io.vertx vertx-core diff --git a/webclient/tests/http2/src/test/java/io/helidon/webclient/tests/http2/MockHttp2Server.java b/webclient/tests/http2/src/test/java/io/helidon/webclient/tests/http2/MockHttp2Server.java new file mode 100644 index 00000000000..f5fffcb7530 --- /dev/null +++ b/webclient/tests/http2/src/test/java/io/helidon/webclient/tests/http2/MockHttp2Server.java @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.webclient.tests.http2; + +import java.net.InetSocketAddress; + +import io.helidon.common.Builder; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpServerUpgradeHandler; +import io.netty.handler.codec.http2.AbstractHttp2ConnectionHandlerBuilder; +import io.netty.handler.codec.http2.CleartextHttp2ServerUpgradeHandler; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2ConnectionHandler; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2Flags; +import io.netty.handler.codec.http2.Http2FrameListener; +import io.netty.handler.codec.http2.Http2FrameLogger; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2ServerUpgradeCodec; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; + +import static java.lang.System.Logger.Level.INFO; + +class MockHttp2Server { + static System.Logger LOGGER = System.getLogger(MockHttp2Server.class.getName()); + private final EventLoopGroup group; + private final InetSocketAddress socketAddress; + + public MockHttp2Server(EventLoopGroup group, InetSocketAddress socketAddress) { + this.group = group; + this.socketAddress = socketAddress; + } + + static MockHttp2ServerBuilder builder() { + return new MockHttp2ServerBuilder(); + } + + int port() { + return socketAddress.getPort(); + } + + void shutdown() { + group.shutdownGracefully(); + } + + @FunctionalInterface + interface Handler { + void handle(ChannelHandlerContext ctx, + int streamId, + Http2Headers headers, + ByteBuf payload, + Http2ConnectionEncoder encoder) throws Http2Exception; + } + + static class MockHttp2ServerBuilder implements Builder { + + private int port = 0; + + private Handler onHeadersHandler = (ctx, streamId, headers, unused, encoder) -> { + Http2Headers h = new DefaultHttp2Headers().status(HttpResponseStatus.OK.codeAsText()); + encoder.writeHeaders(ctx, streamId, h, 0, false, ctx.newPromise()); + encoder.writeData(ctx, + streamId, + Unpooled.wrappedBuffer("Hello World!".getBytes()), + 0, + true, + ctx.newPromise()); + }; + private Handler onGoAwayHandler = (ctx, streamId, headers, data,encoder) -> { + + }; + + @Override + public MockHttp2Server build() { + EventLoopGroup group = new NioEventLoopGroup(); + var initializer = + new ChannelInitializer<>() { + @Override + protected void initChannel(Channel channel) throws Exception { + HttpServerCodec codec = new HttpServerCodec( + 4096, + 16384, + 8192, + true, + 128); + + var mockHandler = new MockHttp2HandlerBuilder(MockHttp2ServerBuilder.this).build(); + + var upgradeHandler = + new HttpServerUpgradeHandler(codec, + protocol -> new Http2ServerUpgradeCodec(mockHandler), + 64 * 1024); + var cleartextHttp2ServerUpgradeHandler = + new CleartextHttp2ServerUpgradeHandler(codec, upgradeHandler, mockHandler); + channel.pipeline() + .addLast(cleartextHttp2ServerUpgradeHandler); + } + }; + ServerBootstrap b = new ServerBootstrap(); + b.option(ChannelOption.SO_BACKLOG, 1024); + b.group(group) + .channel(NioServerSocketChannel.class) + .handler(new LoggingHandler(LogLevel.INFO)) + .childHandler(initializer); + + Channel ch; + try { + ch = b.bind(port).sync().channel(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + InetSocketAddress socketAddress = (InetSocketAddress) ch.localAddress(); + + LOGGER.log(INFO, "HTTP/2 Server is listening on http://127.0.0.1:" + socketAddress.getPort() + '/'); + + return new MockHttp2Server(group, socketAddress); + } + + MockHttp2ServerBuilder port(int port) { + this.port = port; + return this; + } + + MockHttp2ServerBuilder onHeaders(Handler onHeadersHandler) { + this.onHeadersHandler = onHeadersHandler; + return this; + } + + MockHttp2ServerBuilder onGoAway(Handler handlerHandler) { + this.onGoAwayHandler = handlerHandler; + return this; + } + } + + static abstract class Http2Handler extends Http2ConnectionHandler implements Http2FrameListener { + protected Http2Handler(Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings) { + super(decoder, encoder, initialSettings); + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + + return 0; + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, boolean endOfStream) + throws Http2Exception { + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, boolean exclusive) + throws Http2Exception { + + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception { + + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, + int streamId, + int promisedStreamId, + Http2Headers headers, + int padding) throws Http2Exception { + + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) throws Http2Exception { + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload) + throws Http2Exception { + + } + } + + static final class MockHttp2HandlerBuilder + extends AbstractHttp2ConnectionHandlerBuilder { + + private static final Http2FrameLogger LOGGER = new Http2FrameLogger(LogLevel.DEBUG, MockHttp2Server.class); + private final MockHttp2ServerBuilder serverBuilder; + + MockHttp2HandlerBuilder(MockHttp2ServerBuilder serverBuilder) { + this.serverBuilder = serverBuilder; + frameLogger(LOGGER); + } + + @Override + public Http2Handler build() { + return super.build(); + } + + @Override + protected Http2Handler build(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, + Http2Settings initialSettings) { + Http2Handler handler = new Http2Handler(decoder, encoder, initialSettings) { + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, + int streamId, + Http2Headers headers, + int streamDependency, + short weight, + boolean exclusive, + int padding, + boolean endOfStream) throws Http2Exception { + serverBuilder.onHeadersHandler.handle(ctx, streamId, headers, null, encoder()); + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + serverBuilder.onGoAwayHandler.handle(ctx, lastStreamId, null, debugData, encoder()); + } + }; + + frameListener(handler); + return handler; + } + + } +} diff --git a/webclient/tests/http2/src/test/java/io/helidon/webclient/tests/http2/ShutDownTest.java b/webclient/tests/http2/src/test/java/io/helidon/webclient/tests/http2/ShutDownTest.java new file mode 100644 index 00000000000..62dd1a6dbc5 --- /dev/null +++ b/webclient/tests/http2/src/test/java/io/helidon/webclient/tests/http2/ShutDownTest.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.webclient.tests.http2; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import io.helidon.logging.common.LogConfig; +import io.helidon.webclient.http2.Http2Client; +import io.helidon.webclient.http2.Http2ClientProtocolConfig; + +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Headers; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +class ShutDownTest { + + private static final Duration TIMEOUT = Duration.ofSeconds(10); + private static MockHttp2Server mockHttp2Server; + private static int serverPort; + private static ConcurrentMap> goAwayReceived = new ConcurrentHashMap<>(); + + @BeforeAll + static void beforeAll() throws InterruptedException { + LogConfig.configureRuntime(); + mockHttp2Server = MockHttp2Server.builder() + .onGoAway((ctx, streamId, headers, payload, encoder) -> { + int remotePort = ((InetSocketAddress) ctx.channel().remoteAddress()).getPort(); + goAwayReceived.computeIfAbsent(remotePort, i -> new CompletableFuture<>()).complete(null); + }) + .onHeaders((ctx, streamId, headers, unused, encoder) -> { + Http2Headers h = new DefaultHttp2Headers() + .status(HttpResponseStatus.OK.codeAsText()); + + encoder.writeHeaders(ctx, streamId, h, 0, false, ctx.newPromise()); + + int remotePort = ((InetSocketAddress) ctx.channel().remoteAddress()).getPort(); + + encoder.writeData(ctx, + streamId, + Unpooled.wrappedBuffer(String.valueOf(remotePort).getBytes()), + 0, + true, + ctx.newPromise()); + }) + .build(); + serverPort = mockHttp2Server.port(); + } + + @AfterAll + static void afterAll() { + mockHttp2Server.shutdown(); + } + + @Test + void clientConnectionCacheHttp2() { + Http2Client + http2Client = Http2Client.builder() + .shareConnectionCache(false) + .protocolConfig(Http2ClientProtocolConfig.builder().priorKnowledge(true)) + .connectTimeout(Duration.ofMinutes(10)) + .baseUri("http://localhost:" + serverPort) + .build(); + + int clientPort; + try (var res = http2Client.get().request()) { + clientPort = Integer.parseInt(res.entity().as(String.class)); + } + + http2Client.closeResource(); + + try { + goAwayReceived.computeIfAbsent(clientPort, i -> new CompletableFuture<>()) + .get(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + fail("GOAWAY not received from the client with port " + clientPort + "!", e); + } + + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> { + try (var res = http2Client.get().request()) { + res.entity().as(String.class); + } + }); + assertThat(e.getMessage(), is("Connection cache is closed")); + } + + @Test + void globalConnectionCacheHttp2() { + Http2Client + http2Client = Http2Client.builder() + .shareConnectionCache(true) + .protocolConfig(Http2ClientProtocolConfig.builder().priorKnowledge(true)) + .connectTimeout(Duration.ofMinutes(10)) + .baseUri("http://localhost:" + serverPort) + .build(); + + String clientPort; + try (var res = http2Client.get().request()) { + clientPort = res.entity().as(String.class); + } + + // should be noop, not testing the shutdown hook + http2Client.closeResource(); + + try (var res = http2Client.get().request()) { + assertThat(res.entity().as(String.class), is(clientPort)); + } + } +} diff --git a/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2ConfigBlueprint.java b/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2ConfigBlueprint.java index 31e8ecb09e4..4cda86eed25 100644 --- a/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2ConfigBlueprint.java +++ b/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2ConfigBlueprint.java @@ -127,6 +127,14 @@ interface Http2ConfigBlueprint extends ProtocolConfig { @ConfiguredOption("100") int maxRapidResets(); + /** + * Maximum number of consecutive empty frames allowed on connection. + * + * @return max number of consecutive empty frames + */ + @ConfiguredOption("10") + int maxEmptyFrames(); + /** * If set to false, any path is accepted (even containing illegal characters). * diff --git a/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2Connection.java b/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2Connection.java index 172f1d4d8f9..3713d68f4e6 100644 --- a/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2Connection.java +++ b/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2Connection.java @@ -102,9 +102,11 @@ public class Http2Connection implements ServerConnection, InterruptableTask connectionHeaders; private final long rapidResetCheckPeriod; private final int maxRapidResets; + private final int maxEmptyFrames; private final long maxClientConcurrentStreams; private int rapidResetCnt = 0; private long rapidResetPeriodStart = 0; + private int emptyFrames = 0; // initial client settings, until we receive real ones private Http2Settings clientSettings = Http2Settings.builder() .build(); @@ -127,6 +129,7 @@ public class Http2Connection implements ServerConnection, InterruptableTask settingsUpdate(http2Config, builder)) .add(Http2Setting.ENABLE_PUSH, false) @@ -550,9 +553,12 @@ private void dataFrame() { StreamContext stream = stream(streamId); stream.stream().checkDataReceivable(); + boolean endOfStream = frameHeader.flags(Http2FrameTypes.DATA).endOfStream(); + // Flow-control: reading frameHeader.length() bytes from HTTP2 socket for known stream ID. int length = frameHeader.length(); if (length > 0) { + emptyFrames = 0; if (streamId > 0 && frameHeader.type() != Http2FrameType.HEADERS) { // Stream ID > 0: update connection and stream stream.stream() @@ -560,6 +566,10 @@ private void dataFrame() { .inbound() .decrementWindowSize(length); } + } else { + if (emptyFrames++ > maxEmptyFrames && !endOfStream) { + throw new Http2Exception(Http2ErrorCode.ENHANCE_YOUR_CALM, "Too much subsequent empty frames received."); + } } if (frameHeader.flags(Http2FrameTypes.DATA).padded()) { @@ -579,8 +589,6 @@ private void dataFrame() { buffer = inProgressFrame(); } - boolean endOfStream = frameHeader.flags(Http2FrameTypes.DATA).endOfStream(); - // TODO buffer now contains the actual data bytes stream.stream().data(frameHeader, buffer, endOfStream); @@ -758,7 +766,7 @@ private void rstStream() { rapidResetCnt = 1; rapidResetPeriodStart = currentTime; } else if (maxRapidResets < rapidResetCnt) { - throw new Http2Exception(Http2ErrorCode.PROTOCOL, "Rapid reset attack detected!"); + throw new Http2Exception(Http2ErrorCode.ENHANCE_YOUR_CALM, "Rapid reset attack detected!"); } else { rapidResetCnt++; } diff --git a/webserver/testing/junit5/junit5/src/main/java/io/helidon/webserver/testing/junit5/DirectClient.java b/webserver/testing/junit5/junit5/src/main/java/io/helidon/webserver/testing/junit5/DirectClient.java index aa4f75f21ac..6b1a3b95669 100644 --- a/webserver/testing/junit5/junit5/src/main/java/io/helidon/webserver/testing/junit5/DirectClient.java +++ b/webserver/testing/junit5/junit5/src/main/java/io/helidon/webserver/testing/junit5/DirectClient.java @@ -100,6 +100,11 @@ public Http1ClientRequest method(Method method) { .connection(new DirectClientConnection(socket, router)); } + @Override + public void closeResource() { + // Nothing to close in connection-less client + } + /** * Whether to use tls (mark this connection as secure). * diff --git a/webserver/testing/junit5/junit5/src/main/java/io/helidon/webserver/testing/junit5/DirectWebClient.java b/webserver/testing/junit5/junit5/src/main/java/io/helidon/webserver/testing/junit5/DirectWebClient.java index 05629de59d0..415804b3265 100644 --- a/webserver/testing/junit5/junit5/src/main/java/io/helidon/webserver/testing/junit5/DirectWebClient.java +++ b/webserver/testing/junit5/junit5/src/main/java/io/helidon/webserver/testing/junit5/DirectWebClient.java @@ -98,6 +98,11 @@ public HttpClientRequest method(Method method) { .connection(new DirectClientConnection(socket, router)); } + @Override + public void closeResource() { + // Nothing to close in connection-less client + } + @Override public WebClientConfig prototype() { return webClient.prototype(); diff --git a/webserver/tests/http2/src/test/java/io/helidon/webserver/tests/http2/EmptyFrameCntTest.java b/webserver/tests/http2/src/test/java/io/helidon/webserver/tests/http2/EmptyFrameCntTest.java new file mode 100644 index 00000000000..e76e284b410 --- /dev/null +++ b/webserver/tests/http2/src/test/java/io/helidon/webserver/tests/http2/EmptyFrameCntTest.java @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.webserver.tests.http2; + +import java.io.UncheckedIOException; +import java.net.SocketException; +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import io.helidon.common.buffers.BufferData; +import io.helidon.common.buffers.DataReader; +import io.helidon.common.tls.Tls; +import io.helidon.http.Method; +import io.helidon.http.WritableHeaders; +import io.helidon.http.http2.FlowControl; +import io.helidon.http.http2.Http2ConnectionWriter; +import io.helidon.http.http2.Http2DataFrame; +import io.helidon.http.http2.Http2Flag; +import io.helidon.http.http2.Http2FrameData; +import io.helidon.http.http2.Http2FrameHeader; +import io.helidon.http.http2.Http2FrameType; +import io.helidon.http.http2.Http2GoAway; +import io.helidon.http.http2.Http2Headers; +import io.helidon.http.http2.Http2Setting; +import io.helidon.http.http2.Http2Settings; +import io.helidon.http.http2.Http2Util; +import io.helidon.webclient.api.ClientUri; +import io.helidon.webclient.api.ConnectionKey; +import io.helidon.webclient.api.DefaultDnsResolver; +import io.helidon.webclient.api.Proxy; +import io.helidon.webclient.api.TcpClientConnection; +import io.helidon.webclient.api.WebClient; +import io.helidon.webserver.WebServer; +import io.helidon.webserver.WebServerConfig; +import io.helidon.webserver.http.HttpRouting; +import io.helidon.webserver.http2.Http2Config; +import io.helidon.webserver.http2.Http2Route; +import io.helidon.webserver.testing.junit5.ServerTest; +import io.helidon.webserver.testing.junit5.SetUpRoute; +import io.helidon.webserver.testing.junit5.SetUpServer; + +import org.junit.jupiter.api.Test; + +import static io.helidon.http.Method.GET; +import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +@ServerTest +class EmptyFrameCntTest { + + private static final Duration TIMEOUT = Duration.ofSeconds(10); + private final WebServer server; + + public EmptyFrameCntTest(WebServer server) { + this.server = server; + } + + @SetUpRoute + static void router(HttpRouting.Builder router) { + router.route(Http2Route.route(GET, "/", (req, res) -> res.send("pong"))); + } + + @SetUpServer + static void setup(WebServerConfig.Builder server) { + server.addProtocol(Http2Config.builder().sendErrorDetails(true).build()); + } + + @Test + void emptyDataFramesAttack() throws InterruptedException, ExecutionException, TimeoutException { + ClientUri clientUri = ClientUri.create(URI.create("http://localhost:" + server.port())); + ConnectionKey connectionKey = new ConnectionKey(clientUri.scheme(), + clientUri.host(), + clientUri.port(), + Tls.builder().enabled(false).build(), + DefaultDnsResolver.create(), + null, + Proxy.noProxy()); + + TcpClientConnection conn = TcpClientConnection.create(WebClient.builder() + .baseUri(clientUri) + .build(), + connectionKey, + List.of(), + connection -> false, + connection -> { + }) + .connect(); + + BufferData prefaceData = Http2Util.prefaceData(); + conn.writer().writeNow(prefaceData); + Http2ConnectionWriter dataWriter = new Http2ConnectionWriter(conn.helidonSocket(), conn.writer(), List.of()); + + Http2Settings http2Settings = Http2Settings.builder() + .add(Http2Setting.INITIAL_WINDOW_SIZE, 65535L) + .add(Http2Setting.MAX_FRAME_SIZE, 16384L) + .add(Http2Setting.ENABLE_PUSH, false) + .build(); + Http2Flag.SettingsFlags flags = Http2Flag.SettingsFlags.create(0); + Http2FrameData frameData = http2Settings.toFrameData(null, 0, flags); + dataWriter.write(frameData); + + int streamId = 1; + + WritableHeaders headers = WritableHeaders.create(); + Http2Headers h2Headers = Http2Headers.create(headers); + h2Headers.method(Method.GET); + h2Headers.path(clientUri.path().path()); + h2Headers.scheme(clientUri.scheme()); + + dataWriter.writeHeaders(h2Headers, + streamId, + Http2Flag.HeaderFlags.create(Http2Flag.END_OF_HEADERS), + FlowControl.Outbound.NOOP); + + CompletableFuture gotGoAway = new CompletableFuture<>(); + + Thread.ofVirtual().start(() -> { + DataReader reader = conn.reader(); + for (; ; ) { + BufferData frameHeaderBuffer = reader.readBuffer(FRAME_HEADER_LENGTH); + Http2FrameHeader frameHeader = Http2FrameHeader.create(frameHeaderBuffer); + BufferData data = reader.readBuffer(frameHeader.length()); + if (frameHeader.type() == Http2FrameType.GO_AWAY) { + Http2GoAway http2GoAway = Http2GoAway.create(data); + gotGoAway.complete(http2GoAway.errorCode().name() + " - " + new String(data.readBytes())); + break; + } + } + }); + + for (int i = 0; i < 1000; i++) { + try { + Http2DataFrame emptyDataFrame = Http2DataFrame.create(BufferData.create()); + dataWriter.writeData(emptyDataFrame.toFrameData(http2Settings, streamId, Http2Flag.DataFlags.create(0)), + FlowControl.Outbound.NOOP); + + } catch (UncheckedIOException ex) { + assertThat(ex.getCause(), instanceOf(SocketException.class)); + } + } + String http2GoAway = gotGoAway.get(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + assertThat(http2GoAway, is("ENHANCE_YOUR_CALM - Too much subsequent empty frames received.")); + + conn.closeResource(); + } +}