From 488cd2998643ff745610d6773df38809eaf8ca41 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Tue, 6 Nov 2018 17:58:20 -0700 Subject: [PATCH] Open node connections asynchronously (#35144) This is related to #29023. Additionally at other points we have discussed a preference for removing the need to unnecessarily block threads for opening new node connections. This commit lays the groudwork for this by opening connections asynchronously at the transport level. We still block, however, this work will make it possible to eventually remove all blocking on new connections out of the TransportService and Transport. --- .../elasticsearch/transport/Netty4Plugin.java | 5 +- .../transport/netty4/Netty4TcpChannel.java | 27 +- .../transport/netty4/Netty4Transport.java | 34 +- .../netty4/Netty4ScheduledPingTests.java | 9 +- .../Netty4SizeHeaderFrameDecoderTests.java | 3 +- .../transport/netty4/Netty4TransportIT.java | 2 +- .../netty4/NettyTransportMultiPortTests.java | 3 +- .../netty4/SimpleNetty4TransportTests.java | 16 +- .../discovery/ec2/Ec2DiscoveryTests.java | 4 +- .../transport/nio/NioTcpChannel.java | 5 + .../transport/nio/NioTransport.java | 13 +- .../transport/nio/NioTransportPlugin.java | 5 +- .../transport/nio/NioTransportIT.java | 3 +- .../nio/SimpleNioTransportTests.java | 18 +- .../common/network/CloseableChannel.java | 3 +- .../elasticsearch/transport/TcpChannel.java | 51 +-- .../elasticsearch/transport/TcpTransport.java | 401 +++++++----------- .../transport/TcpTransportHandshaker.java | 185 ++++++++ .../TransportReplicationActionTests.java | 4 +- .../discovery/zen/UnicastZenPingTests.java | 15 +- .../TcpTransportHandshakerTests.java | 135 ++++++ .../transport/TcpTransportTests.java | 8 +- .../AbstractSimpleTransportTestCase.java | 11 +- .../transport/MockTcpTransport.java | 24 +- .../transport/nio/MockNioTransport.java | 16 +- .../transport/nio/MockNioTransportPlugin.java | 5 +- .../transport/MockTcpTransportTests.java | 10 +- .../nio/SimpleMockNioTransportTests.java | 19 +- .../xpack/core/XPackClientPlugin.java | 3 +- .../netty4/SecurityNetty4Transport.java | 4 +- .../xpack/security/Security.java | 4 +- .../netty4/SecurityNetty4ServerTransport.java | 4 +- .../transport/nio/SecurityNioTransport.java | 7 +- ...stractSimpleSecurityTransportTestCase.java | 10 +- .../SecurityNetty4ServerTransportTests.java | 2 + ...pleSecurityNetty4ServerTransportTests.java | 17 +- .../nio/SimpleSecurityNioTransportTests.java | 21 +- 37 files changed, 655 insertions(+), 451 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java create mode 100644 server/src/test/java/org/elasticsearch/transport/TcpTransportHandshakerTests.java diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/Netty4Plugin.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/Netty4Plugin.java index 70afcc86ad8f9..c2c841f889aff 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/Netty4Plugin.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/Netty4Plugin.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkModule; import org.elasticsearch.common.network.NetworkService; @@ -81,8 +82,8 @@ public Map> getTransports(Settings settings, ThreadP CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { - return Collections.singletonMap(NETTY_TRANSPORT_NAME, () -> new Netty4Transport(settings, threadPool, networkService, bigArrays, - namedWriteableRegistry, circuitBreakerService)); + return Collections.singletonMap(NETTY_TRANSPORT_NAME, () -> new Netty4Transport(settings, Version.CURRENT, threadPool, + networkService, bigArrays, namedWriteableRegistry, circuitBreakerService)); } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4TcpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4TcpChannel.java index bee98362e0c1e..af66b7c79881a 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4TcpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4TcpChannel.java @@ -21,11 +21,15 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPromise; + import java.io.IOException; + import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.concurrent.CompletableContext; import org.elasticsearch.transport.TcpChannel; @@ -37,11 +41,13 @@ public class Netty4TcpChannel implements TcpChannel { private final Channel channel; private final String profile; + private final CompletableContext connectContext; private final CompletableContext closeContext = new CompletableContext<>(); - Netty4TcpChannel(Channel channel, String profile) { + Netty4TcpChannel(Channel channel, String profile, @Nullable ChannelFuture connectFuture) { this.channel = channel; this.profile = profile; + this.connectContext = new CompletableContext<>(); this.channel.closeFuture().addListener(f -> { if (f.isSuccess()) { closeContext.complete(null); @@ -55,6 +61,20 @@ public class Netty4TcpChannel implements TcpChannel { } } }); + + connectFuture.addListener(f -> { + if (f.isSuccess()) { + connectContext.complete(null); + } else { + Throwable cause = f.cause(); + if (cause instanceof Error) { + ExceptionsHelper.maybeDieOnAnotherThread(cause); + connectContext.completeExceptionally(new Exception(cause)); + } else { + connectContext.completeExceptionally((Exception) cause); + } + } + }); } @Override @@ -72,6 +92,11 @@ public void addCloseListener(ActionListener listener) { closeContext.addListener(ActionListener.toBiConsumer(listener)); } + @Override + public void addConnectListener(ActionListener listener) { + connectContext.addListener(ActionListener.toBiConsumer(listener)); + } + @Override public void setSoLinger(int value) throws IOException { if (channel.isOpen()) { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index a4e5731cd6226..b34f50de0f041 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -38,7 +38,7 @@ import io.netty.util.concurrent.Future; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.action.ActionListener; +import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -101,9 +101,9 @@ public class Netty4Transport extends TcpTransport { private volatile Bootstrap clientBootstrap; private volatile NioEventLoopGroup eventLoopGroup; - public Netty4Transport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, + public Netty4Transport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { - super("netty", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); + super("netty", settings, version, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings)); this.workerCount = WORKER_COUNT.get(settings); @@ -216,37 +216,23 @@ protected ChannelHandler getClientChannelInitializer(DiscoveryNode node) { static final AttributeKey SERVER_CHANNEL_KEY = AttributeKey.newInstance("es-server-channel"); @Override - protected Netty4TcpChannel initiateChannel(DiscoveryNode node, ActionListener listener) throws IOException { + protected Netty4TcpChannel initiateChannel(DiscoveryNode node) throws IOException { InetSocketAddress address = node.getAddress().address(); Bootstrap bootstrapWithHandler = clientBootstrap.clone(); bootstrapWithHandler.handler(getClientChannelInitializer(node)); bootstrapWithHandler.remoteAddress(address); - ChannelFuture channelFuture = bootstrapWithHandler.connect(); + ChannelFuture connectFuture = bootstrapWithHandler.connect(); - Channel channel = channelFuture.channel(); + Channel channel = connectFuture.channel(); if (channel == null) { - ExceptionsHelper.maybeDieOnAnotherThread(channelFuture.cause()); - throw new IOException(channelFuture.cause()); + ExceptionsHelper.maybeDieOnAnotherThread(connectFuture.cause()); + throw new IOException(connectFuture.cause()); } addClosedExceptionLogger(channel); - Netty4TcpChannel nettyChannel = new Netty4TcpChannel(channel, "default"); + Netty4TcpChannel nettyChannel = new Netty4TcpChannel(channel, "default", connectFuture); channel.attr(CHANNEL_KEY).set(nettyChannel); - channelFuture.addListener(f -> { - if (f.isSuccess()) { - listener.onResponse(null); - } else { - Throwable cause = f.cause(); - if (cause instanceof Error) { - ExceptionsHelper.maybeDieOnAnotherThread(cause); - listener.onFailure(new Exception(cause)); - } else { - listener.onFailure((Exception) cause); - } - } - }); - return nettyChannel; } @@ -309,7 +295,7 @@ protected ServerChannelInitializer(String name) { @Override protected void initChannel(Channel ch) throws Exception { addClosedExceptionLogger(ch); - Netty4TcpChannel nettyTcpChannel = new Netty4TcpChannel(ch, name); + Netty4TcpChannel nettyTcpChannel = new Netty4TcpChannel(ch, name, ch.newSucceededFuture()); ch.attr(CHANNEL_KEY).set(nettyTcpChannel); ch.pipeline().addLast("logging", new ESLoggingHandler()); ch.pipeline().addLast("size", new Netty4SizeHeaderFrameDecoder()); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java index 0f3185add0833..bae0cb7cef980 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.transport.netty4; +import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; @@ -59,15 +60,15 @@ public void testScheduledPing() throws Exception { CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService(); NamedWriteableRegistry registry = new NamedWriteableRegistry(Collections.emptyList()); - final Netty4Transport nettyA = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), - BigArrays.NON_RECYCLING_INSTANCE, registry, circuitBreakerService); + final Netty4Transport nettyA = new Netty4Transport(settings, Version.CURRENT, threadPool, + new NetworkService(Collections.emptyList()), BigArrays.NON_RECYCLING_INSTANCE, registry, circuitBreakerService); MockTransportService serviceA = new MockTransportService(settings, nettyA, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, null); serviceA.start(); serviceA.acceptIncomingRequests(); - final Netty4Transport nettyB = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), - BigArrays.NON_RECYCLING_INSTANCE, registry, circuitBreakerService); + final Netty4Transport nettyB = new Netty4Transport(settings, Version.CURRENT, threadPool, + new NetworkService(Collections.emptyList()), BigArrays.NON_RECYCLING_INSTANCE, registry, circuitBreakerService); MockTransportService serviceB = new MockTransportService(settings, nettyB, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, null); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java index 564cf61a39569..a711bb690e366 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport.netty4; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; @@ -65,7 +66,7 @@ public void startThreadPool() { threadPool = new ThreadPool(settings); NetworkService networkService = new NetworkService(Collections.emptyList()); BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - nettyTransport = new Netty4Transport(settings, threadPool, networkService, bigArrays, + nettyTransport = new Netty4Transport(settings, Version.CURRENT, threadPool, networkService, bigArrays, new NamedWriteableRegistry(Collections.emptyList()), new NoneCircuitBreakerService()); nettyTransport.start(); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java index b81c8efcb47ee..b93e09b53649e 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java @@ -108,7 +108,7 @@ public ExceptionThrowingNetty4Transport( BigArrays bigArrays, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { - super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); + super(settings, Version.CURRENT, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); } @Override diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/NettyTransportMultiPortTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/NettyTransportMultiPortTests.java index a49df3caaba4e..785c4cfb114bc 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/NettyTransportMultiPortTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/NettyTransportMultiPortTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.transport.netty4; +import org.elasticsearch.Version; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -118,7 +119,7 @@ public void testThatDefaultProfilePortOverridesGeneralConfiguration() throws Exc private TcpTransport startTransport(Settings settings, ThreadPool threadPool) { BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - TcpTransport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), + TcpTransport transport = new Netty4Transport(settings, Version.CURRENT, threadPool, new NetworkService(Collections.emptyList()), bigArrays, new NamedWriteableRegistry(Collections.emptyList()), new NoneCircuitBreakerService()); transport.start(); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java index e7faac8ae01db..4c651c31bee7e 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.transport.netty4; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -40,7 +41,6 @@ import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportService; -import java.io.IOException; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Collections; @@ -54,23 +54,17 @@ public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase public static MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, ClusterSettings clusterSettings, boolean doHandshake) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); - Transport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), + Transport transport = new Netty4Transport(settings, version, threadPool, new NetworkService(Collections.emptyList()), BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, channel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } - - @Override - protected Version getCurrentVersion() { - return version; - } }; MockTransportService mockTransportService = MockTransportService.createNewService(settings, transport, version, threadPool, clusterSettings, Collections.emptySet()); diff --git a/plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java b/plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java index aa619409c16eb..98f2febd79516 100644 --- a/plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java +++ b/plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java @@ -20,7 +20,6 @@ package org.elasticsearch.discovery.ec2; import com.amazonaws.services.ec2.model.Tag; -import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; @@ -74,8 +73,7 @@ public static void stopThreadPool() throws InterruptedException { public void createTransportService() { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); final Transport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE, - new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), - Version.CURRENT) { + new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList())) { @Override public TransportAddress[] addressesFromString(String address, int perAddressLimit) throws UnknownHostException { // we just need to ensure we don't resolve DNS here diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTcpChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTcpChannel.java index 947a255b178c8..480043acbd899 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTcpChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTcpChannel.java @@ -58,6 +58,11 @@ public void addCloseListener(ActionListener listener) { addCloseListener(ActionListener.toBiConsumer(listener)); } + @Override + public void addConnectListener(ActionListener listener) { + addConnectListener(ActionListener.toBiConsumer(listener)); + } + @Override public void close() { getContext().closeChannel(); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 15f7d1e28943f..ab1e1411c3b81 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -20,7 +20,7 @@ package org.elasticsearch.transport.nio; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.action.ActionListener; +import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -66,10 +66,10 @@ public class NioTransport extends TcpTransport { private volatile NioGroup nioGroup; private volatile TcpChannelFactory clientChannelFactory; - protected NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, - PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, - CircuitBreakerService circuitBreakerService) { - super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); + protected NioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, + PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, + CircuitBreakerService circuitBreakerService) { + super("nio", settings, version, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); this.pageCacheRecycler = pageCacheRecycler; } @@ -80,10 +80,9 @@ protected NioTcpServerChannel bind(String name, InetSocketAddress address) throw } @Override - protected NioTcpChannel initiateChannel(DiscoveryNode node, ActionListener connectListener) throws IOException { + protected NioTcpChannel initiateChannel(DiscoveryNode node) throws IOException { InetSocketAddress address = node.getAddress().address(); NioTcpChannel channel = nioGroup.openChannel(address, clientChannelFactory); - channel.addConnectListener(ActionListener.toBiConsumer(connectListener)); return channel; } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransportPlugin.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransportPlugin.java index 1da8e909b2dd8..fd57ea20b1c8d 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransportPlugin.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransportPlugin.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport.nio; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Setting; @@ -61,8 +62,8 @@ public Map> getTransports(Settings settings, ThreadP NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { return Collections.singletonMap(NIO_TRANSPORT_NAME, - () -> new NioTransport(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, - circuitBreakerService)); + () -> new NioTransport(settings, Version.CURRENT, threadPool, networkService, bigArrays, pageCacheRecycler, + namedWriteableRegistry, circuitBreakerService)); } @Override diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java index df53a4d79c7ad..0c1bad79ee8e6 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java @@ -104,7 +104,8 @@ public Map> getTransports(Settings settings, ThreadP ExceptionThrowingNioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { - super(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService); + super(settings, Version.CURRENT, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, + circuitBreakerService); } @Override diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java index 33d40b9f735fa..8fc1dd04dd7b6 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.transport.nio; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -57,24 +58,17 @@ public static MockTransportService nioFromThreadPool(Settings settings, ThreadPo ClusterSettings clusterSettings, boolean doHandshake) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); NetworkService networkService = new NetworkService(Collections.emptyList()); - Transport transport = new NioTransport(settings, threadPool, - networkService, BigArrays.NON_RECYCLING_INSTANCE, new MockPageCacheRecycler(settings), namedWriteableRegistry, - new NoneCircuitBreakerService()) { + Transport transport = new NioTransport(settings, version, threadPool, networkService, BigArrays.NON_RECYCLING_INSTANCE, + new MockPageCacheRecycler(settings), namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, channel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } - - @Override - protected Version getCurrentVersion() { - return version; - } }; MockTransportService mockTransportService = MockTransportService.createNewService(settings, transport, version, threadPool, clusterSettings, Collections.emptySet()); diff --git a/server/src/main/java/org/elasticsearch/common/network/CloseableChannel.java b/server/src/main/java/org/elasticsearch/common/network/CloseableChannel.java index 6b89a90aa2c77..4fc3a0f6bb6bd 100644 --- a/server/src/main/java/org/elasticsearch/common/network/CloseableChannel.java +++ b/server/src/main/java/org/elasticsearch/common/network/CloseableChannel.java @@ -26,7 +26,6 @@ import java.io.Closeable; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -90,7 +89,7 @@ static void closeChannels(List channels, boolean IOUtils.close(channels); } catch (IOException e) { // The CloseableChannel#close method does not throw IOException, so this should not occur. - throw new UncheckedIOException(e); + throw new AssertionError(e); } if (blocking) { ArrayList> futures = new ArrayList<>(channels.size()); diff --git a/server/src/main/java/org/elasticsearch/transport/TcpChannel.java b/server/src/main/java/org/elasticsearch/transport/TcpChannel.java index bc5cc2c92f2cb..f4d265389d3d4 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpChannel.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpChannel.java @@ -19,19 +19,12 @@ package org.elasticsearch.transport; -import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.network.CloseableChannel; -import org.elasticsearch.common.unit.TimeValue; import java.io.IOException; import java.net.InetSocketAddress; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; /** @@ -46,7 +39,6 @@ public interface TcpChannel extends CloseableChannel { */ String getProfile(); - /** * This sets the low level socket option {@link java.net.StandardSocketOptions} SO_LINGER on a channel. * @@ -55,7 +47,6 @@ public interface TcpChannel extends CloseableChannel { */ void setSoLinger(int value) throws IOException; - /** * Returns the local address for this channel. * @@ -80,42 +71,12 @@ public interface TcpChannel extends CloseableChannel { void sendMessage(BytesReference reference, ActionListener listener); /** - * Awaits for all of the pending connections to complete. Will throw an exception if at least one of the - * connections fails. + * Adds a listener that will be executed when the channel is connected. If the channel is still + * unconnected when this listener is added, the listener will be executed by the thread that eventually + * finishes the channel connection. If the channel is already connected when the listener is added the + * listener will immediately be executed by the thread that is attempting to add the listener. * - * @param discoveryNode the node for the pending connections - * @param connectionFutures representing the pending connections - * @param connectTimeout to wait for a connection - * @throws ConnectTransportException if one of the connections fails + * @param listener to be executed */ - static void awaitConnected(DiscoveryNode discoveryNode, List> connectionFutures, TimeValue connectTimeout) - throws ConnectTransportException { - Exception connectionException = null; - boolean allConnected = true; - - for (ActionFuture connectionFuture : connectionFutures) { - try { - connectionFuture.get(connectTimeout.getMillis(), TimeUnit.MILLISECONDS); - } catch (TimeoutException e) { - allConnected = false; - break; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException(e); - } catch (ExecutionException e) { - allConnected = false; - connectionException = (Exception) e.getCause(); - break; - } - } - - if (allConnected == false) { - if (connectionException == null) { - throw new ConnectTransportException(discoveryNode, "connect_timeout[" + connectTimeout + "]"); - } else { - throw new ConnectTransportException(discoveryNode, "connect_exception", connectionException); - } - } - } - + void addConnectListener(ActionListener listener); } diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index 46067930df110..eedd064bca7ca 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -23,7 +23,6 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; -import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NotifyOnceListener; import org.elasticsearch.action.support.PlainActionFuture; @@ -45,7 +44,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.metrics.CounterMetric; import org.elasticsearch.common.metrics.MeanMetric; import org.elasticsearch.common.network.CloseableChannel; @@ -61,6 +59,7 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.indices.breaker.CircuitBreakerService; @@ -88,7 +87,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; @@ -100,7 +98,6 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.function.Consumer; import java.util.function.Function; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -180,6 +177,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements protected final Settings settings; private final CircuitBreakerService circuitBreakerService; + private final Version version; protected final ThreadPool threadPool; private final BigArrays bigArrays; protected final NetworkService networkService; @@ -200,24 +198,22 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private volatile BoundTransportAddress boundAddress; private final String transportName; - private final ConcurrentMap pendingHandshakes = new ConcurrentHashMap<>(); - private final CounterMetric numHandshakes = new CounterMetric(); - private static final String HANDSHAKE_ACTION_NAME = "internal:tcp/handshake"; - private final MeanMetric readBytesMetric = new MeanMetric(); private final MeanMetric transmittedBytesMetric = new MeanMetric(); private volatile Map> requestHandlers = Collections.emptyMap(); private final ResponseHandlers responseHandlers = new ResponseHandlers(); + private final TcpTransportHandshaker handshaker; private final TransportLogger transportLogger; private final BytesReference pingMessage; private final String nodeName; - public TcpTransport(String transportName, Settings settings, ThreadPool threadPool, BigArrays bigArrays, + public TcpTransport(String transportName, Settings settings, Version version, ThreadPool threadPool, BigArrays bigArrays, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { super(settings); this.settings = settings; this.profileSettings = getProfileSettings(settings); + this.version = version; this.threadPool = threadPool; this.bigArrays = bigArrays; this.circuitBreakerService = circuitBreakerService; @@ -226,6 +222,12 @@ public TcpTransport(String transportName, Settings settings, ThreadPool threadPo this.networkService = networkService; this.transportName = transportName; this.transportLogger = new TransportLogger(); + this.handshaker = new TcpTransportHandshaker(version, threadPool, + (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId, + TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportRequest.Empty.INSTANCE, TransportRequestOptions.EMPTY, v, + TransportStatus.setHandshake((byte) 0)), + (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId, + TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, TransportStatus.setHandshake((byte) 0))); this.nodeName = Node.NODE_NAME_SETTING.get(settings); final Settings defaultFeatures = DEFAULT_FEATURES_SETTING.get(settings); @@ -277,41 +279,6 @@ public synchronized void registerRequestHandl requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap(); } - private static class HandshakeResponseHandler implements TransportResponseHandler { - final AtomicReference versionRef = new AtomicReference<>(); - final CountDownLatch latch = new CountDownLatch(1); - final AtomicReference exceptionRef = new AtomicReference<>(); - final TcpChannel channel; - - HandshakeResponseHandler(TcpChannel channel) { - this.channel = channel; - } - - @Override - public VersionHandshakeResponse read(StreamInput in) throws IOException { - return new VersionHandshakeResponse(in); - } - - @Override - public void handleResponse(VersionHandshakeResponse response) { - final boolean success = versionRef.compareAndSet(null, response.version); - latch.countDown(); - assert success; - } - - @Override - public void handleException(TransportException exp) { - final boolean success = exceptionRef.compareAndSet(null, exp); - latch.countDown(); - assert success; - } - - @Override - public String executor() { - return ThreadPool.Names.SAME; - } - } - public final class NodeChannels extends CloseableConnection { private final Map typeMapping; private final List channels; @@ -433,83 +400,59 @@ public NodeChannels openConnection(DiscoveryNode node, ConnectionProfile connect if (node == null) { throw new ConnectTransportException(null, "can't open connection to a null node"); } - boolean success = false; - NodeChannels nodeChannels = null; connectionProfile = maybeOverrideConnectionProfile(connectionProfile); closeLock.readLock().lock(); // ensure we don't open connections while we are closing try { ensureOpen(); - try { - int numConnections = connectionProfile.getNumConnections(); - assert numConnections > 0 : "A connection profile must be configured with at least one connection"; - List channels = new ArrayList<>(numConnections); - List> connectionFutures = new ArrayList<>(numConnections); - for (int i = 0; i < numConnections; ++i) { - try { - PlainActionFuture connectFuture = PlainActionFuture.newFuture(); - connectionFutures.add(connectFuture); - TcpChannel channel = initiateChannel(node, connectFuture); - logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel)); - channels.add(channel); - } catch (Exception e) { - // If there was an exception when attempting to instantiate the raw channels, we close all of the channels - CloseableChannel.closeChannels(channels, false); - throw e; - } - } + PlainActionFuture connectionFuture = PlainActionFuture.newFuture(); + List pendingChannels = initiateConnection(node, connectionProfile, connectionFuture); - // If we make it past the block above, we successfully instantiated all of the channels - try { - TcpChannel.awaitConnected(node, connectionFutures, connectionProfile.getConnectTimeout()); - } catch (Exception ex) { - CloseableChannel.closeChannels(channels, false); - throw ex; - } - - // If we make it past the block above, we have successfully established connections for all of the channels - final TcpChannel handshakeChannel = channels.get(0); // one channel is guaranteed by the connection profile - handshakeChannel.addCloseListener(ActionListener.wrap(() -> cancelHandshakeForChannel(handshakeChannel))); - Version version; - try { - version = executeHandshake(node, handshakeChannel, connectionProfile.getHandshakeTimeout()); - } catch (Exception ex) { - CloseableChannel.closeChannels(channels, false); - throw ex; + try { + return connectionFuture.actionGet(); + } catch (IllegalStateException e) { + // If the future was interrupted we can close the channels to improve the shutdown of the MockTcpTransport + if (e.getCause() instanceof InterruptedException) { + CloseableChannel.closeChannels(pendingChannels, false); } + throw e; + } + } finally { + closeLock.readLock().unlock(); + } + } - // If we make it past the block above, we have successfully completed the handshake and the connection is now open. - // At this point we should construct the connection, notify the transport service, and attach close listeners to the - // underlying channels. - nodeChannels = new NodeChannels(node, channels, connectionProfile, version); - final NodeChannels finalNodeChannels = nodeChannels; + private List initiateConnection(DiscoveryNode node, ConnectionProfile connectionProfile, + ActionListener listener) { + int numConnections = connectionProfile.getNumConnections(); + assert numConnections > 0 : "A connection profile must be configured with at least one connection"; - Consumer onClose = c -> { - assert c.isOpen() == false : "channel is still open when onClose is called"; - finalNodeChannels.close(); - }; + final List channels = new ArrayList<>(numConnections); - nodeChannels.channels.forEach(ch -> ch.addCloseListener(ActionListener.wrap(() -> onClose.accept(ch)))); - success = true; - return nodeChannels; + for (int i = 0; i < numConnections; ++i) { + try { + TcpChannel channel = initiateChannel(node); + logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel)); + channels.add(channel); } catch (ConnectTransportException e) { - throw e; + CloseableChannel.closeChannels(channels, false); + listener.onFailure(e); + return channels; } catch (Exception e) { - // ConnectTransportExceptions are handled specifically on the caller end - we wrap the actual exception to ensure - // only relevant exceptions are logged on the caller end.. this is the same as in connectToNode - throw new ConnectTransportException(node, "general node connection failure", e); - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(nodeChannels); - } + CloseableChannel.closeChannels(channels, false); + listener.onFailure(new ConnectTransportException(node, "general node connection failure", e)); + return channels; } - } finally { - closeLock.readLock().unlock(); } - } - protected Version getCurrentVersion() { - // this is just for tests to mock stuff like the nodes version - tests can override this internally - return Version.CURRENT; + ChannelsConnectedListener channelsConnectedListener = new ChannelsConnectedListener(node, connectionProfile, channels, listener); + + for (TcpChannel channel : channels) { + channel.addConnectListener(channelsConnectedListener); + } + + TimeValue connectTimeout = connectionProfile.getConnectTimeout(); + threadPool.schedule(connectTimeout, ThreadPool.Names.GENERIC, channelsConnectedListener::onTimeout); + return channels; } @Override @@ -677,7 +620,9 @@ public TransportAddress[] addressesFromString(String address, int perAddressLimi // not perfect, but PortsRange should take care of any port range validation, not a regex private static final Pattern BRACKET_PATTERN = Pattern.compile("^\\[(.*:.*)\\](?::([\\d\\-]*))?$"); - /** parse a hostname+port range spec into its equivalent addresses */ + /** + * parse a hostname+port range spec into its equivalent addresses + */ static TransportAddress[] parse(String hostPortString, String defaultPortRange, int perAddressLimit) throws UnknownHostException { Objects.requireNonNull(hostPortString); String host; @@ -775,7 +720,7 @@ public void onException(TcpChannel channel, Exception e) { if (isCloseConnectionException(e)) { logger.trace(() -> new ParameterizedMessage( - "close connection exception caught on transport layer [{}], disconnecting from relevant node", channel), e); + "close connection exception caught on transport layer [{}], disconnecting from relevant node", channel), e); // close the channel, which will cause a node to be disconnected if relevant CloseableChannel.closeChannel(channel); } else if (isConnectException(e)) { @@ -788,7 +733,7 @@ public void onException(TcpChannel channel, Exception e) { CloseableChannel.closeChannel(channel); } else if (e instanceof CancelledKeyException) { logger.trace(() -> new ParameterizedMessage( - "cancelled key exception caught on transport layer [{}], disconnecting from relevant node", channel), e); + "cancelled key exception caught on transport layer [{}], disconnecting from relevant node", channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant CloseableChannel.closeChannel(channel); } else if (e instanceof TcpTransport.HttpOnTransportException) { @@ -856,11 +801,10 @@ protected void serverAcceptedChannel(TcpChannel channel) { * Initiate a single tcp socket channel. * * @param node for the initiated connection - * @param connectListener listener to be called when connection complete * @return the pending connection * @throws IOException if an I/O exception occurs while opening the channel */ - protected abstract TcpChannel initiateChannel(DiscoveryNode node, ActionListener connectListener) throws IOException; + protected abstract TcpChannel initiateChannel(DiscoveryNode node) throws IOException; /** * Called to tear down internal resources @@ -894,7 +838,7 @@ private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel cha // we pick the smallest of the 2, to support both backward and forward compatibility // note, this is the only place we need to do this, since from here on, we use the serialized version // as the version to use also when the node receiving this request will send the response with - Version version = Version.min(getCurrentVersion(), channelVersion); + Version version = Version.min(this.version, channelVersion); stream.setVersion(version); threadPool.getThreadContext().writeTo(stream); @@ -941,12 +885,12 @@ private void internalSendMessage(TcpChannel channel, BytesReference message, Sen * @param action the action this response replies to */ public void sendErrorResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final Exception error, - final long requestId, - final String action) throws IOException { + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final Exception error, + final long requestId, + final String action) throws IOException { try (BytesStreamOutput stream = new BytesStreamOutput()) { stream.setVersion(nodeVersion); stream.setFeatures(features); @@ -972,25 +916,25 @@ public void sendErrorResponse( * @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending back errors to the caller */ public void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final TransportResponse response, - final long requestId, - final String action, - final TransportResponseOptions options) throws IOException { + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final TransportResponse response, + final long requestId, + final String action, + final TransportResponseOptions options) throws IOException { sendResponse(nodeVersion, features, channel, response, requestId, action, options, (byte) 0); } private void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final TransportResponse response, - final long requestId, - final String action, - TransportResponseOptions options, - byte status) throws IOException { + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final TransportResponse response, + final long requestId, + final String action, + TransportResponseOptions options, + byte status) throws IOException { if (compress) { options = TransportResponseOptions.builder(options).withCompress(true).build(); } @@ -1087,13 +1031,13 @@ public void inboundMessage(TcpChannel channel, BytesReference message) { * Consumes bytes that are available from network reads. This method returns the number of bytes consumed * in this call. * - * @param channel the channel read from + * @param channel the channel read from * @param bytesReference the bytes available to consume * @return the number of bytes consumed - * @throws StreamCorruptedException if the message header format is not recognized + * @throws StreamCorruptedException if the message header format is not recognized * @throws TcpTransport.HttpOnTransportException if the message header appears to be an HTTP message - * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. - * This is dependent on the available memory. + * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. + * This is dependent on the available memory. */ public int consumeNetworkReads(TcpChannel channel, BytesReference bytesReference) throws IOException { BytesReference message = decodeFrame(bytesReference); @@ -1112,10 +1056,10 @@ public int consumeNetworkReads(TcpChannel channel, BytesReference bytesReference * * @param networkBytes the will be read * @return the message decoded - * @throws StreamCorruptedException if the message header format is not recognized + * @throws StreamCorruptedException if the message header format is not recognized * @throws TcpTransport.HttpOnTransportException if the message header appears to be an HTTP message - * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. - * This is dependent on the available memory. + * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. + * This is dependent on the available memory. */ static BytesReference decodeFrame(BytesReference networkBytes) throws IOException { int messageLength = readMessageLength(networkBytes); @@ -1139,10 +1083,10 @@ static BytesReference decodeFrame(BytesReference networkBytes) throws IOExceptio * * @param networkBytes the will be read * @return the length of the message - * @throws StreamCorruptedException if the message header format is not recognized + * @throws StreamCorruptedException if the message header format is not recognized * @throws TcpTransport.HttpOnTransportException if the message header appears to be an HTTP message - * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. - * This is dependent on the available memory. + * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. + * This is dependent on the available memory. */ public static int readMessageLength(BytesReference networkBytes) throws IOException { if (networkBytes.length() < BYTES_NEEDED_FOR_MESSAGE_SIZE) { @@ -1265,7 +1209,7 @@ public final void messageReceived(BytesReference reference, TcpChannel channel) streamIn = compressor.streamInput(streamIn); } final boolean isHandshake = TransportStatus.isHandshake(status); - ensureVersionCompatibility(version, getCurrentVersion(), isHandshake); + ensureVersionCompatibility(version, this.version, isHandshake); streamIn = new NamedWriteableAwareStreamInput(streamIn, namedWriteableRegistry); streamIn.setVersion(version); threadPool.getThreadContext().readHeaders(streamIn); @@ -1275,12 +1219,12 @@ public final void messageReceived(BytesReference reference, TcpChannel channel) } else { final TransportResponseHandler handler; if (isHandshake) { - handler = pendingHandshakes.remove(requestId); + handler = handshaker.removeHandlerForHandshake(requestId); } else { TransportResponseHandler theHandler = responseHandlers.onResponseReceived(requestId, messageListener); if (theHandler == null && TransportStatus.isError(status)) { - handler = pendingHandshakes.remove(requestId); + handler = handshaker.removeHandlerForHandshake(requestId); } else { handler = theHandler; } @@ -1325,7 +1269,7 @@ static void ensureVersionCompatibility(Version version, Version currentVersion, } private void handleResponse(InetSocketAddress remoteAddress, final StreamInput stream, - final TransportResponseHandler handler) { + final TransportResponseHandler handler) { final T response; try { response = handler.read(stream); @@ -1390,9 +1334,7 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str TransportChannel transportChannel = null; try { if (TransportStatus.isHandshake(status)) { - final VersionHandshakeResponse response = new VersionHandshakeResponse(getCurrentVersion()); - sendResponse(version, features, channel, response, requestId, HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, - TransportStatus.setHandshake((byte) 0)); + handshaker.handleHandshake(version, features, channel, requestId); } else { final RequestHandlerRegistry reg = getRequestHandler(action); if (reg == null) { @@ -1415,7 +1357,7 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str // the circuit breaker tripped if (transportChannel == null) { transportChannel = - new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, profileName, 0); + new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, profileName, 0); } try { transportChannel.sendResponse(e); @@ -1468,100 +1410,22 @@ public void onFailure(Exception e) { } catch (Exception inner) { inner.addSuppressed(e); logger.warn(() -> new ParameterizedMessage( - "Failed to send error message back to client for action [{}]", reg.getAction()), inner); + "Failed to send error message back to client for action [{}]", reg.getAction()), inner); } } } } - private static final class VersionHandshakeResponse extends TransportResponse { - private final Version version; - - private VersionHandshakeResponse(Version version) { - this.version = version; - } - - private VersionHandshakeResponse(StreamInput in) throws IOException { - super.readFrom(in); - version = Version.readVersion(in); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - assert version != null; - Version.writeVersion(version, out); - } - } - - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) - throws IOException, InterruptedException { - numHandshakes.inc(); - final long requestId = responseHandlers.newRequestId(); - final HandshakeResponseHandler handler = new HandshakeResponseHandler(channel); - AtomicReference versionRef = handler.versionRef; - AtomicReference exceptionRef = handler.exceptionRef; - pendingHandshakes.put(requestId, handler); - boolean success = false; - try { - if (channel.isOpen() == false) { - // we have to protect us here since sendRequestToChannel won't barf if the channel is closed. - // it's weird but to change it will cause a lot of impact on the exception handling code all over the codebase. - // yet, if we don't check the state here we might have registered a pending handshake handler but the close - // listener calling #onChannelClosed might have already run and we are waiting on the latch below unitl we time out. - throw new IllegalStateException("handshake failed, channel already closed"); - } - // for the request we use the minCompatVersion since we don't know what's the version of the node we talk to - // we also have no payload on the request but the response will contain the actual version of the node we talk - // to as the payload. - final Version minCompatVersion = getCurrentVersion().minimumCompatibilityVersion(); - sendRequestToChannel(node, channel, requestId, HANDSHAKE_ACTION_NAME, TransportRequest.Empty.INSTANCE, - TransportRequestOptions.EMPTY, minCompatVersion, TransportStatus.setHandshake((byte) 0)); - if (handler.latch.await(timeout.millis(), TimeUnit.MILLISECONDS) == false) { - throw new ConnectTransportException(node, "handshake_timeout[" + timeout + "]"); - } - success = true; - if (exceptionRef.get() != null) { - throw new IllegalStateException("handshake failed", exceptionRef.get()); - } else { - Version version = versionRef.get(); - if (getCurrentVersion().isCompatible(version) == false) { - throw new IllegalStateException("Received message from unsupported version: [" + version - + "] minimal compatible version is: [" + getCurrentVersion().minimumCompatibilityVersion() + "]"); - } - return version; - } - } finally { - final TransportResponseHandler removedHandler = pendingHandshakes.remove(requestId); - // in the case of a timeout or an exception on the send part the handshake has not been removed yet. - // but the timeout is tricky since it's basically a race condition so we only assert on the success case. - assert success && removedHandler == null || success == false : "handler for requestId [" + requestId + "] is not been removed"; - } + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { + handshaker.sendHandshake(responseHandlers.newRequestId(), node, channel, timeout, listener); } - final int getNumPendingHandshakes() { // for testing - return pendingHandshakes.size(); + final int getNumPendingHandshakes() { + return handshaker.getNumPendingHandshakes(); } final long getNumHandshakes() { - return numHandshakes.count(); // for testing - } - - /** - * Called once the channel is closed for instance due to a disconnect or a closed socket etc. - */ - private void cancelHandshakeForChannel(TcpChannel channel) { - final Optional first = pendingHandshakes.entrySet().stream() - .filter((entry) -> entry.getValue().channel == channel).map(Map.Entry::getKey).findFirst(); - if (first.isPresent()) { - final Long requestId = first.get(); - final HandshakeResponseHandler handler = pendingHandshakes.remove(requestId); - if (handler != null) { - // there might be a race removing this or this method might be called twice concurrently depending on how - // the channel is closed ie. due to connection reset or broken pipes - handler.handleException(new TransportException("connection reset")); - } - } + return handshaker.getNumHandshakes(); } /** @@ -1741,4 +1605,69 @@ public final ResponseHandlers getResponseHandlers() { public final RequestHandlerRegistry getRequestHandler(String action) { return requestHandlers.get(action); } + + private final class ChannelsConnectedListener implements ActionListener { + + private final DiscoveryNode node; + private final ConnectionProfile connectionProfile; + private final List channels; + private final ActionListener listener; + private final CountDown countDown; + + private ChannelsConnectedListener(DiscoveryNode node, ConnectionProfile connectionProfile, List channels, + ActionListener listener) { + this.node = node; + this.connectionProfile = connectionProfile; + this.channels = channels; + this.listener = listener; + this.countDown = new CountDown(channels.size()); + } + + @Override + public void onResponse(Void v) { + // Returns true if all connections have completed successfully + if (countDown.countDown()) { + final TcpChannel handshakeChannel = channels.get(0); + try { + executeHandshake(node, handshakeChannel, connectionProfile.getHandshakeTimeout(), new ActionListener() { + @Override + public void onResponse(Version version) { + NodeChannels nodeChannels = new NodeChannels(node, channels, connectionProfile, version); + nodeChannels.channels.forEach(ch -> ch.addCloseListener(ActionListener.wrap(nodeChannels::close))); + listener.onResponse(nodeChannels); + } + + @Override + public void onFailure(Exception e) { + CloseableChannel.closeChannels(channels, false); + + if (e instanceof ConnectTransportException) { + listener.onFailure(e); + } else { + listener.onFailure(new ConnectTransportException(node, "general node connection failure", e)); + } + } + }); + } catch (Exception ex) { + CloseableChannel.closeChannels(channels, false); + listener.onFailure(ex); + } + } + } + + @Override + public void onFailure(Exception ex) { + if (countDown.fastForward()) { + CloseableChannel.closeChannels(channels, false); + listener.onFailure(new ConnectTransportException(node, "connect_exception", ex)); + } + } + + public void onTimeout() { + if (countDown.fastForward()) { + CloseableChannel.closeChannels(channels, false); + listener.onFailure(new ConnectTransportException(node, "connect_timeout[" + connectionProfile.getConnectTimeout() + "]")); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java b/server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java new file mode 100644 index 0000000000000..d1037d2bcb5bd --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java @@ -0,0 +1,185 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Sends and receives transport-level connection handshakes. This class will send the initial handshake, + * manage state/timeouts while the handshake is in transit, and handle the eventual response. + */ +final class TcpTransportHandshaker { + + static final String HANDSHAKE_ACTION_NAME = "internal:tcp/handshake"; + private final ConcurrentMap pendingHandshakes = new ConcurrentHashMap<>(); + private final CounterMetric numHandshakes = new CounterMetric(); + + private final Version version; + private final ThreadPool threadPool; + private final HandshakeRequestSender handshakeRequestSender; + private final HandshakeResponseSender handshakeResponseSender; + + TcpTransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender, + HandshakeResponseSender handshakeResponseSender) { + this.version = version; + this.threadPool = threadPool; + this.handshakeRequestSender = handshakeRequestSender; + this.handshakeResponseSender = handshakeResponseSender; + } + + void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { + numHandshakes.inc(); + final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, version, listener); + pendingHandshakes.put(requestId, handler); + channel.addCloseListener(ActionListener.wrap( + () -> handler.handleLocalException(new TransportException("handshake failed because connection reset")))); + boolean success = false; + try { + // for the request we use the minCompatVersion since we don't know what's the version of the node we talk to + // we also have no payload on the request but the response will contain the actual version of the node we talk + // to as the payload. + final Version minCompatVersion = version.minimumCompatibilityVersion(); + handshakeRequestSender.sendRequest(node, channel, requestId, minCompatVersion); + + threadPool.schedule(timeout, ThreadPool.Names.GENERIC, + () -> handler.handleLocalException(new ConnectTransportException(node, "handshake_timeout[" + timeout + "]"))); + success = true; + } catch (Exception e) { + handler.handleLocalException(new ConnectTransportException(node, "failure to send " + HANDSHAKE_ACTION_NAME, e)); + } finally { + if (success == false) { + TransportResponseHandler removed = pendingHandshakes.remove(requestId); + assert removed == null : "Handshake should not be pending if exception was thrown"; + } + } + } + + void handleHandshake(Version version, Set features, TcpChannel channel, long requestId) throws IOException { + handshakeResponseSender.sendResponse(version, features, channel, new VersionHandshakeResponse(this.version), requestId); + } + + TransportResponseHandler removeHandlerForHandshake(long requestId) { + return pendingHandshakes.remove(requestId); + } + + int getNumPendingHandshakes() { + return pendingHandshakes.size(); + } + + long getNumHandshakes() { + return numHandshakes.count(); + } + + private class HandshakeResponseHandler implements TransportResponseHandler { + + private final long requestId; + private final Version currentVersion; + private final ActionListener listener; + private final AtomicBoolean isDone = new AtomicBoolean(false); + + private HandshakeResponseHandler(long requestId, Version currentVersion, ActionListener listener) { + this.requestId = requestId; + this.currentVersion = currentVersion; + this.listener = listener; + } + + @Override + public VersionHandshakeResponse read(StreamInput in) throws IOException { + return new VersionHandshakeResponse(in); + } + + @Override + public void handleResponse(VersionHandshakeResponse response) { + if (isDone.compareAndSet(false, true)) { + Version version = response.version; + if (currentVersion.isCompatible(version) == false) { + listener.onFailure(new IllegalStateException("Received message from unsupported version: [" + version + + "] minimal compatible version is: [" + currentVersion.minimumCompatibilityVersion() + "]")); + } else { + listener.onResponse(version); + } + } + } + + @Override + public void handleException(TransportException e) { + if (isDone.compareAndSet(false, true)) { + listener.onFailure(new IllegalStateException("handshake failed", e)); + } + } + + void handleLocalException(TransportException e) { + if (removeHandlerForHandshake(requestId) != null && isDone.compareAndSet(false, true)) { + listener.onFailure(e); + } + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + } + + static final class VersionHandshakeResponse extends TransportResponse { + + private final Version version; + + VersionHandshakeResponse(Version version) { + this.version = version; + } + + private VersionHandshakeResponse(StreamInput in) throws IOException { + super.readFrom(in); + version = Version.readVersion(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + assert version != null; + Version.writeVersion(version, out); + } + } + + @FunctionalInterface + interface HandshakeRequestSender { + + void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException; + } + + @FunctionalInterface + interface HandshakeResponseSender { + + void sendResponse(Version version, Set features, TcpChannel channel, TransportResponse response, long requestId) + throws IOException; + } +} diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java index 0469fac4d7d55..cbe490c87aec4 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java @@ -21,7 +21,6 @@ import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.UnavailableShardsException; import org.elasticsearch.action.admin.indices.close.CloseIndexRequest; @@ -981,8 +980,7 @@ public void testRetryOnReplicaWithRealTransport() throws Exception { final ReplicationTask task = maybeTask(); NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); final Transport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE, - new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), - Version.CURRENT); + new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList())); transportService = new MockTransportService(Settings.EMPTY, transport, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> clusterService.localNode(), null, Collections.emptySet()); transportService.start(); diff --git a/server/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java b/server/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java index 4ab738f5c7bc3..ed310ee305acf 100644 --- a/server/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java @@ -377,8 +377,7 @@ public void testPortLimit() throws InterruptedException { BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - Version.CURRENT) { + networkService) { @Override public BoundTransportAddress boundAddress() { @@ -419,8 +418,7 @@ public void testRemovingLocalAddresses() throws InterruptedException { BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - Version.CURRENT) { + networkService) { @Override public BoundTransportAddress boundAddress() { @@ -465,8 +463,7 @@ public void testUnknownHost() throws InterruptedException { BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - Version.CURRENT) { + networkService) { @Override public BoundTransportAddress boundAddress() { @@ -512,8 +509,7 @@ public void testResolveTimeout() throws InterruptedException { BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - Version.CURRENT) { + networkService) { @Override public BoundTransportAddress boundAddress() { @@ -578,8 +574,7 @@ public void testResolveReuseExistingNodeConnections() throws ExecutionException, BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - v); + networkService); NetworkHandle handleA = startServices(settings, threadPool, "UZP_A", Version.CURRENT, supplier, EnumSet.allOf(Role.class)); closeables.push(handleA.transportService); diff --git a/server/src/test/java/org/elasticsearch/transport/TcpTransportHandshakerTests.java b/server/src/test/java/org/elasticsearch/transport/TcpTransportHandshakerTests.java new file mode 100644 index 0000000000000..23e3870842e20 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/transport/TcpTransportHandshakerTests.java @@ -0,0 +1,135 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class TcpTransportHandshakerTests extends ESTestCase { + + private TcpTransportHandshaker handshaker; + private DiscoveryNode node; + private TcpChannel channel; + private TestThreadPool threadPool; + private TcpTransportHandshaker.HandshakeRequestSender requestSender; + private TcpTransportHandshaker.HandshakeResponseSender responseSender; + + @Override + public void setUp() throws Exception { + super.setUp(); + String nodeId = "node-id"; + channel = mock(TcpChannel.class); + requestSender = mock(TcpTransportHandshaker.HandshakeRequestSender.class); + responseSender = mock(TcpTransportHandshaker.HandshakeResponseSender.class); + node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(), + Collections.emptySet(), Version.CURRENT); + threadPool = new TestThreadPool("thread-poll"); + handshaker = new TcpTransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender); + } + + @Override + public void tearDown() throws Exception { + threadPool.shutdown(); + super.tearDown(); + } + + public void testHandshakeRequestAndResponse() throws IOException { + PlainActionFuture versionFuture = PlainActionFuture.newFuture(); + long reqId = randomLongBetween(1, 10); + handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture); + + verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion()); + + assertFalse(versionFuture.isDone()); + + TcpChannel mockChannel = mock(TcpChannel.class); + handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId); + + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TransportResponse.class); + verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(), + eq(reqId)); + + TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); + handler.handleResponse((TcpTransportHandshaker.VersionHandshakeResponse) responseCaptor.getValue()); + + assertTrue(versionFuture.isDone()); + assertEquals(Version.CURRENT, versionFuture.actionGet()); + } + + public void testHandshakeError() throws IOException { + PlainActionFuture versionFuture = PlainActionFuture.newFuture(); + long reqId = randomLongBetween(1, 10); + handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture); + + verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion()); + + assertFalse(versionFuture.isDone()); + + TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); + handler.handleException(new TransportException("failed")); + + assertTrue(versionFuture.isDone()); + IllegalStateException ise = expectThrows(IllegalStateException.class, versionFuture::actionGet); + assertThat(ise.getMessage(), containsString("handshake failed")); + } + + public void testSendRequestThrowsException() throws IOException { + PlainActionFuture versionFuture = PlainActionFuture.newFuture(); + long reqId = randomLongBetween(1, 10); + Version compatibilityVersion = Version.CURRENT.minimumCompatibilityVersion(); + doThrow(new IOException("boom")).when(requestSender).sendRequest(node, channel, reqId, compatibilityVersion); + + handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture); + + + assertTrue(versionFuture.isDone()); + ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet); + assertThat(cte.getMessage(), containsString("failure to send internal:tcp/handshake")); + assertNull(handshaker.removeHandlerForHandshake(reqId)); + } + + public void testHandshakeTimeout() throws IOException { + PlainActionFuture versionFuture = PlainActionFuture.newFuture(); + long reqId = randomLongBetween(1, 10); + handshaker.sendHandshake(reqId, node, channel, new TimeValue(100, TimeUnit.MILLISECONDS), versionFuture); + + verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion()); + + ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet); + assertThat(cte.getMessage(), containsString("handshake_timeout")); + + assertNull(handshaker.removeHandlerForHandshake(reqId)); + } +} diff --git a/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java b/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java index a17103789f251..b9ce7d3be3700 100644 --- a/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java @@ -188,7 +188,7 @@ public void testCompressRequest() throws IOException { AtomicReference messageCaptor = new AtomicReference<>(); try { TcpTransport transport = new TcpTransport( - "test", Settings.builder().put("transport.tcp.compress", compressed).build(), threadPool, + "test", Settings.builder().put("transport.tcp.compress", compressed).build(), Version.CURRENT, threadPool, new BigArrays(new PageCacheRecycler(Settings.EMPTY), null), null, null, null) { @Override @@ -197,7 +197,7 @@ protected FakeChannel bind(String name, InetSocketAddress address) throws IOExce } @Override - protected FakeChannel initiateChannel(DiscoveryNode node, ActionListener connectListener) throws IOException { + protected FakeChannel initiateChannel(DiscoveryNode node) throws IOException { return new FakeChannel(messageCaptor); } @@ -271,6 +271,10 @@ public String getProfile() { public void addCloseListener(ActionListener listener) { } + @Override + public void addConnectListener(ActionListener listener) { + } + @Override public void setSoLinger(int value) throws IOException { } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index f4cf6e09642de..2015bbf353de0 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -199,9 +199,7 @@ public void tearDown() throws Exception { assertNoPendingHandshakes(serviceA.getOriginalTransport()); assertNoPendingHandshakes(serviceB.getOriginalTransport()); } finally { - IOUtils.close(serviceA, serviceB, () -> { - terminate(threadPool); - }); + IOUtils.close(serviceA, serviceB, () -> terminate(threadPool)); } } @@ -2030,9 +2028,10 @@ protected String handleRequest(TcpChannel mockChannel, String profileName, Strea TcpTransport.NodeChannels connection = originalTransport.openConnection( new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0), connectionProfile)) { - Version version = originalTransport.executeHandshake(connection.getNode(), - connection.channel(TransportRequestOptions.Type.PING), TimeValue.timeValueSeconds(10)); - assertEquals(version, Version.CURRENT); + PlainActionFuture listener = PlainActionFuture.newFuture(); + originalTransport.executeHandshake(connection.getNode(), connection.channel(TransportRequestOptions.Type.PING), + TimeValue.timeValueSeconds(10), listener); + assertEquals(listener.actionGet(), Version.CURRENT); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index 99aa540b68411..2fddb42d57034 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -88,7 +88,6 @@ public class MockTcpTransport extends TcpTransport { } private final ExecutorService executor; - private final Version mockVersion; public MockTcpTransport(Settings settings, ThreadPool threadPool, BigArrays bigArrays, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, @@ -100,11 +99,11 @@ public MockTcpTransport(Settings settings, ThreadPool threadPool, BigArrays bigA public MockTcpTransport(Settings settings, ThreadPool threadPool, BigArrays bigArrays, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService, Version mockVersion) { - super("mock-tcp-transport", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); + super("mock-tcp-transport", settings, mockVersion, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, + networkService); // we have our own crazy cached threadpool this one is not bounded at all... // using the ES thread factory here is crucial for tests otherwise disruption tests won't block that thread executor = Executors.newCachedThreadPool(EsExecutors.daemonThreadFactory(settings, Transports.TEST_MOCK_TRANSPORT_THREAD_PREFIX)); - this.mockVersion = mockVersion; } @Override @@ -163,7 +162,7 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx @Override @SuppressForbidden(reason = "real socket for mocking remote connections") - protected MockChannel initiateChannel(DiscoveryNode node, ActionListener connectListener) throws IOException { + protected MockChannel initiateChannel(DiscoveryNode node) throws IOException { InetSocketAddress address = node.getAddress().address(); final MockSocket socket = new MockSocket(); final MockChannel channel = new MockChannel(socket, address, "none"); @@ -176,17 +175,16 @@ protected MockChannel initiateChannel(DiscoveryNode node, ActionListener c if (success == false) { IOUtils.close(socket); } - } executor.submit(() -> { try { socket.connect(address); socket.setSoLinger(false, 0); + channel.connectFuture.complete(null); channel.loopRead(executor); - connectListener.onResponse(null); } catch (Exception ex) { - connectListener.onFailure(ex); + channel.connectFuture.completeExceptionally(ex); } }); @@ -238,6 +236,7 @@ public final class MockChannel implements Closeable, TcpChannel, TcpServerChanne private final String profile; private final CancellableThreads cancellableThreads = new CancellableThreads(); private final CompletableContext closeFuture = new CompletableContext<>(); + private final CompletableContext connectFuture = new CompletableContext<>(); /** * Constructs a new MockChannel instance intended for handling the actual incoming / outgoing traffic. @@ -386,12 +385,16 @@ public void addCloseListener(ActionListener listener) { closeFuture.addListener(ActionListener.toBiConsumer(listener)); } + @Override + public void addConnectListener(ActionListener listener) { + connectFuture.addListener(ActionListener.toBiConsumer(listener)); + } + @Override public void setSoLinger(int value) throws IOException { if (activeChannel != null && activeChannel.isClosed() == false) { activeChannel.setSoLinger(true, value); } - } @Override @@ -452,10 +455,5 @@ protected void stopInternal() { assert openChannels.isEmpty() : "there are still open channels: " + openChannels; } } - - @Override - protected Version getCurrentVersion() { - return mockVersion; - } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index dc08fbf257d66..0fb6114207361 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -21,6 +21,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.bytes.BytesReference; @@ -69,10 +70,10 @@ public class MockNioTransport extends TcpTransport { private volatile NioGroup nioGroup; private volatile MockTcpChannelFactory clientChannelFactory; - MockNioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, + MockNioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { - super("mock-nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); + super("mock-nio", settings, version, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); this.pageCacheRecycler = pageCacheRecycler; } @@ -83,11 +84,9 @@ protected MockServerChannel bind(String name, InetSocketAddress address) throws } @Override - protected MockSocketChannel initiateChannel(DiscoveryNode node, ActionListener connectListener) throws IOException { + protected MockSocketChannel initiateChannel(DiscoveryNode node) throws IOException { InetSocketAddress address = node.getAddress().address(); - MockSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory); - channel.addConnectListener(ActionListener.toBiConsumer(connectListener)); - return channel; + return nioGroup.openChannel(address, clientChannelFactory); } @Override @@ -272,6 +271,11 @@ public void addCloseListener(ActionListener listener) { addCloseListener(ActionListener.toBiConsumer(listener)); } + @Override + public void addConnectListener(ActionListener listener) { + addConnectListener(ActionListener.toBiConsumer(listener)); + } + @Override public void setSoLinger(int value) throws IOException { SocketChannel rawChannel = getRawChannel(); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransportPlugin.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransportPlugin.java index 1acd947d5aad2..ceabe72ee4436 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransportPlugin.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransportPlugin.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.transport.nio; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; @@ -44,7 +45,7 @@ public Map> getTransports(Settings settings, ThreadP NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { return Collections.singletonMap(MOCK_NIO_TRANSPORT_NAME, - () -> new MockNioTransport(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, - circuitBreakerService)); + () -> new MockNioTransport(settings, Version.CURRENT, threadPool, networkService, bigArrays, pageCacheRecycler, + namedWriteableRegistry, circuitBreakerService)); } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java index e8b5f38b88df1..1e5c6092687a6 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -29,7 +30,6 @@ import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.test.transport.MockTransportService; -import java.io.IOException; import java.util.Collections; public class MockTcpTransportTests extends AbstractSimpleTransportTestCase { @@ -39,13 +39,13 @@ protected MockTransportService build(Settings settings, Version version, Cluster NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); Transport transport = new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version) { + @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel mockChannel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, mockChannel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } }; diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleMockNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleMockNioTransportTests.java index 10f089e855a5d..c6ba13d4ca7b1 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleMockNioTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleMockNioTransportTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.transport.nio; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -57,25 +58,17 @@ public static MockTransportService nioFromThreadPool(Settings settings, ThreadPo ClusterSettings clusterSettings, boolean doHandshake) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); NetworkService networkService = new NetworkService(Collections.emptyList()); - Transport transport = new MockNioTransport(settings, threadPool, - networkService, BigArrays.NON_RECYCLING_INSTANCE, new MockPageCacheRecycler(settings), namedWriteableRegistry, - new NoneCircuitBreakerService()) { + Transport transport = new MockNioTransport(settings, version, threadPool, networkService, BigArrays.NON_RECYCLING_INSTANCE, + new MockPageCacheRecycler(settings), namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, channel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } - - @Override - protected Version getCurrentVersion() { - return version; - } - }; MockTransportService mockTransportService = MockTransportService.createNewService(settings, transport, version, threadPool, clusterSettings, Collections.emptySet()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 1d11f3df1721d..b774b2990e096 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core; +import org.elasticsearch.Version; import org.elasticsearch.action.Action; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.cluster.ClusterState; @@ -470,7 +471,7 @@ public Map> getTransports( } catch (Exception e) { throw new RuntimeException(e); } - return Collections.singletonMap(SecurityField.NAME4, () -> new SecurityNetty4Transport(settings, threadPool, + return Collections.singletonMap(SecurityField.NAME4, () -> new SecurityNetty4Transport(settings, Version.CURRENT, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, sslService)); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java index e76302aebb058..d135506c1f427 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java @@ -12,6 +12,7 @@ import io.netty.channel.ChannelPromise; import io.netty.handler.ssl.SslHandler; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.CloseableChannel; @@ -54,13 +55,14 @@ public class SecurityNetty4Transport extends Netty4Transport { public SecurityNetty4Transport( final Settings settings, + final Version version, final ThreadPool threadPool, final NetworkService networkService, final BigArrays bigArrays, final NamedWriteableRegistry namedWriteableRegistry, final CircuitBreakerService circuitBreakerService, final SSLService sslService) { - super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); + super(settings, version, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); this.sslService = sslService; this.sslEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings); if (sslEnabled) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index ae0b34dde8cdc..8b48843b1fffc 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -899,9 +899,9 @@ public Map> getTransports(Settings settings, ThreadP } Map> transports = new HashMap<>(); - transports.put(SecurityField.NAME4, () -> new SecurityNetty4ServerTransport(settings, threadPool, + transports.put(SecurityField.NAME4, () -> new SecurityNetty4ServerTransport(settings, Version.CURRENT, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, ipFilter.get(), getSslService())); - transports.put(SecurityField.NIO, () -> new SecurityNioTransport(settings, threadPool, + transports.put(SecurityField.NIO, () -> new SecurityNioTransport(settings, Version.CURRENT, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService, ipFilter.get(), getSslService())); return Collections.unmodifiableMap(transports); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransport.java index e0794d037e33d..d74aa65e94bee 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransport.java @@ -7,6 +7,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -25,6 +26,7 @@ public class SecurityNetty4ServerTransport extends SecurityNetty4Transport { public SecurityNetty4ServerTransport( final Settings settings, + final Version version, final ThreadPool threadPool, final NetworkService networkService, final BigArrays bigArrays, @@ -32,7 +34,7 @@ public SecurityNetty4ServerTransport( final CircuitBreakerService circuitBreakerService, @Nullable final IPFilter authenticator, final SSLService sslService) { - super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, sslService); + super(settings, version, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, sslService); this.authenticator = authenticator; } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index 71e14696a11ff..d9e4080865e44 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.security.transport.nio; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.CloseableChannel; @@ -64,11 +65,11 @@ public class SecurityNioTransport extends NioTransport { private final Map profileConfiguration; private final boolean sslEnabled; - public SecurityNioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, - PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, + public SecurityNioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, + BigArrays bigArrays, PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService, @Nullable final IPFilter authenticator, SSLService sslService) { - super(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService); + super(settings, version, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService); this.authenticator = authenticator; this.sslService = sslService; this.sslEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java index 077edf22c91ca..3b98bc8aa5f98 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java @@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import org.elasticsearch.Version; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; @@ -106,7 +107,7 @@ public void testBindUnavailableAddress() { } @Override - public void testTcpHandshake() throws IOException, InterruptedException { + public void testTcpHandshake() throws InterruptedException { assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport); TcpTransport originalTransport = (TcpTransport) serviceA.getOriginalTransport(); @@ -115,9 +116,10 @@ public void testTcpHandshake() throws IOException, InterruptedException { TcpTransport.NodeChannels connection = originalTransport.openConnection( new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0), connectionProfile)) { - Version version = originalTransport.executeHandshake(connection.getNode(), - connection.channel(TransportRequestOptions.Type.PING), TimeValue.timeValueSeconds(10)); - assertEquals(version, Version.CURRENT); + PlainActionFuture listener = PlainActionFuture.newFuture(); + originalTransport.executeHandshake(connection.getNode(), connection.channel(TransportRequestOptions.Type.PING), + TimeValue.timeValueSeconds(10), listener); + assertEquals(listener.actionGet(), Version.CURRENT); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportTests.java index e9d91f5bd2d6a..dc6bffe5c7271 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportTests.java @@ -8,6 +8,7 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.ssl.SslHandler; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.MockSecureSettings; @@ -68,6 +69,7 @@ private SecurityNetty4Transport createTransport(Settings additionalSettings) { .build(); return new SecurityNetty4ServerTransport( settings, + Version.CURRENT, mock(ThreadPool.class), new NetworkService(Collections.emptyList()), mock(BigArrays.class), diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java index 291b39f4b05ba..8c4dcf9e2fac5 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java @@ -13,6 +13,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.ssl.SslHandler; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -39,7 +40,6 @@ import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; -import java.io.IOException; import java.net.InetSocketAddress; import java.util.Collections; import java.util.EnumSet; @@ -72,25 +72,18 @@ public MockTransportService nettyFromThreadPool(Settings settings, ThreadPool th Settings settings1 = Settings.builder() .put(settings) .put("xpack.security.transport.ssl.enabled", true).build(); - Transport transport = new SecurityNetty4ServerTransport(settings1, threadPool, + Transport transport = new SecurityNetty4ServerTransport(settings1, version, threadPool, networkService, BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService(), null, createSSLService(settings1)) { @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, channel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } - - @Override - protected Version getCurrentVersion() { - return version; - } - }; MockTransportService mockTransportService = MockTransportService.createNewService(settings, transport, version, threadPool, clusterSettings, diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SimpleSecurityNioTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SimpleSecurityNioTransportTests.java index 7fd4d8b5e0319..5f336e2e5d38c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SimpleSecurityNioTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SimpleSecurityNioTransportTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.security.transport.nio; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -22,7 +23,6 @@ import org.elasticsearch.transport.Transport; import org.elasticsearch.xpack.security.transport.AbstractSimpleSecurityTransportTestCase; -import java.io.IOException; import java.util.Collections; public class SimpleSecurityNioTransportTests extends AbstractSimpleSecurityTransportTestCase { @@ -34,25 +34,18 @@ public MockTransportService nioFromThreadPool(Settings settings, ThreadPool thre Settings settings1 = Settings.builder() .put(settings) .put("xpack.security.transport.ssl.enabled", true).build(); - Transport transport = new SecurityNioTransport(settings1, threadPool, - networkService, BigArrays.NON_RECYCLING_INSTANCE, new MockPageCacheRecycler(settings), namedWriteableRegistry, - new NoneCircuitBreakerService(), null, createSSLService(settings1)) { + Transport transport = new SecurityNioTransport(settings1, version, threadPool, networkService, BigArrays.NON_RECYCLING_INSTANCE, + new MockPageCacheRecycler(settings), namedWriteableRegistry, new NoneCircuitBreakerService(), null, + createSSLService(settings1)) { @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, channel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } - - @Override - protected Version getCurrentVersion() { - return version; - } - }; MockTransportService mockTransportService = MockTransportService.createNewService(settings, transport, version, threadPool, clusterSettings,