From 2e658c83cbdaf5a42759be48ede9705134f40f9a Mon Sep 17 00:00:00 2001 From: Jason Tedor Date: Mon, 9 Oct 2017 13:36:50 -0400 Subject: [PATCH] Check for closed connection while opening While opening a connection to a node, a channel can subsequently close. If this happens, a future callback whose purpose is to close all other channels and disconnect from the node will fire. However, this future will not be ready to close all the channels because the connection will not be exposed to the future callback yet. Since this callback is run once, we will never try to disconnect from this node again and we will be left with a closed channel. This commit adds a check that all channels are open before exposing the channel and throws a general connection exception. In this case, the usual connection retry logic will take over. --- .../elasticsearch/transport/TcpTransport.java | 13 +++++- .../transport/TCPTransportTests.java | 12 ++++- .../transport/netty4/Netty4Transport.java | 4 +- .../netty4/SimpleNetty4TransportTests.java | 26 +++++++++-- .../AbstractSimpleTransportTestCase.java | 46 +++++++++++++++---- .../transport/MockTcpTransport.java | 5 +- .../transport/nio/NioClient.java | 6 ++- .../transport/nio/NioTransport.java | 5 +- .../transport/MockTcpTransportTests.java | 23 +++++++++- .../transport/nio/NioClientTests.java | 10 ++-- .../nio/SimpleNioTransportTests.java | 30 ++++++++++-- 11 files changed, 145 insertions(+), 35 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 6bf731f2936d9..c217aa98de370 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -588,7 +588,10 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c } } }; - nodeChannels = connectToChannels(node, connectionProfile, onClose); + nodeChannels = connectToChannels(node, connectionProfile, this::onChannelOpen, onClose); + if (!Arrays.stream(nodeChannels.channels).allMatch(this::isOpen)) { + throw new ConnectTransportException(node, "a channel closed while connecting"); + } final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ? defaultConnectionProfile.getConnectTimeout() : @@ -617,6 +620,10 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c } } + protected void onChannelOpen(final Channel channel) { + + } + private void disconnectFromNodeCloseAndNotify(DiscoveryNode node, NodeChannels nodeChannels) { assert nodeChannels != null : "nodeChannels must not be null"; try { @@ -1034,7 +1041,9 @@ protected void innerOnFailure(Exception e) { */ protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener listener); - protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile, + protected abstract NodeChannels connectToChannels(DiscoveryNode node, + ConnectionProfile connectionProfile, + Consumer onChannelOpen, Consumer onChannelClose) throws IOException; /** diff --git a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java index 55457cc8ae431..fa1f2dddfb56d 100644 --- a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java @@ -37,6 +37,7 @@ import java.io.IOException; import java.net.InetSocketAddress; +import java.nio.channels.Channel; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -224,9 +225,16 @@ protected void sendMessage(Object o, BytesReference reference, ActionListener li } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, + protected NodeChannels connectToChannels(DiscoveryNode node, + ConnectionProfile profile, + Consumer onChannelOpen, Consumer onChannelClose) throws IOException { - return new NodeChannels(node, new Object[profile.getNumConnections()], profile); + + final Object[] objects = new Object[profile.getNumConnections()]; + for (int i = 0; i < objects.length; i++) { + onChannelOpen.accept(objects[i]); + } + return new NodeChannels(node, objects, profile); } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index 84c86bd2d770a..a7d8dc4679b32 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -252,7 +252,8 @@ public long getNumOpenServerConnections() { } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) { + protected NodeChannels connectToChannels( + DiscoveryNode node, ConnectionProfile profile, Consumer onChannelOpen, Consumer onChannelClose) { final Channel[] channels = new Channel[profile.getNumConnections()]; final NodeChannels nodeChannels = new NodeChannels(node, channels, profile); boolean success = false; @@ -283,6 +284,7 @@ protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile p if (!future.isSuccess()) { throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", future.cause()); } + onChannelOpen.accept(future.channel()); channels[i] = future.channel(); channels[i].closeFuture().addListener(closeListener); } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java index 92c21f942c292..d094cdb987be4 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.transport.netty4; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -44,15 +45,16 @@ import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Collections; +import java.util.function.Consumer; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; import static org.hamcrest.Matchers.containsString; -public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase { +public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase { public static MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, - ClusterSettings clusterSettings, boolean doHandshake) { + ClusterSettings clusterSettings, boolean doHandshake, Consumer onChannelOpen) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); Transport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @@ -67,6 +69,11 @@ protected Version executeHandshake(DiscoveryNode node, Channel channel, TimeValu } } + @Override + protected void onChannelOpen(Channel channel) { + onChannelOpen.accept(channel); + } + @Override protected Version getCurrentVersion() { return version; @@ -79,13 +86,21 @@ protected Version getCurrentVersion() { } @Override - protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { + protected MockTransportService build( + Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, Consumer onChannelOpen) { settings = Settings.builder().put(settings).put(TcpTransport.PORT.getKey(), "0").build(); - MockTransportService transportService = nettyFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake); + MockTransportService transportService = + nettyFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake, onChannelOpen); transportService.start(); return transportService; } + @Override + protected void close(Channel channel) { + final ChannelFuture future = channel.close(); + future.awaitUninterruptibly(); + } + public void testConnectException() throws UnknownHostException { try { serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876), @@ -108,7 +123,8 @@ public void testBindUnavailableAddress() { .build(); ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> { - MockTransportService transportService = nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true); + MockTransportService transportService = + nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true, c -> {}); try { transportService.start(); } finally { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index da43f116d4245..d170ba64f79f3 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -79,17 +79,20 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.stream.Collectors; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.startsWith; -public abstract class AbstractSimpleTransportTestCase extends ESTestCase { +public abstract class AbstractSimpleTransportTestCase extends ESTestCase { protected ThreadPool threadPool; // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions @@ -105,7 +108,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { protected volatile DiscoveryNode nodeB; protected volatile MockTransportService serviceB; - protected abstract MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake); + protected abstract MockTransportService build( + Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, Consumer onChannelOpen); @Override @Before @@ -146,6 +150,12 @@ public void onNodeDisconnected(DiscoveryNode node) { private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings, Settings settings, boolean acceptRequests, boolean doHandshake) { + return buildService(name, version, clusterSettings, settings, acceptRequests, doHandshake, c -> {}); + } + + private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings, + Settings settings, boolean acceptRequests, boolean doHandshake, + Consumer onChannelOpen) { MockTransportService service = build( Settings.builder() .put(settings) @@ -154,7 +164,7 @@ private MockTransportService buildService(final String name, final Version versi .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") .build(), version, - clusterSettings, doHandshake); + clusterSettings, doHandshake, onChannelOpen); if (acceptRequests) { service.acceptIncomingRequests(); } @@ -1692,7 +1702,7 @@ public void testSendRandomRequests() throws InterruptedException { .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") .build(), version0, - null, true); + null, true, c -> {}); DiscoveryNode nodeC = serviceC.getLocalNode(); serviceC.acceptIncomingRequests(); @@ -2125,7 +2135,7 @@ public String executor() { public void testHandlerIsInvokedOnConnectionClose() throws IOException, InterruptedException { List executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet()); CollectionUtil.timSort(executors); // makes sure it's reproducible - TransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + TransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {}); serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, (request, channel) -> { // do nothing @@ -2183,7 +2193,7 @@ public String executor() { } public void testConcurrentDisconnectOnNonPublishedConnection() throws IOException, InterruptedException { - MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {}); CountDownLatch receivedLatch = new CountDownLatch(1); CountDownLatch sendResponseLatch = new CountDownLatch(1); serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, @@ -2251,7 +2261,7 @@ public String executor() { } public void testTransportStats() throws Exception { - MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {}); CountDownLatch receivedLatch = new CountDownLatch(1); CountDownLatch sendResponseLatch = new CountDownLatch(1); serviceB.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, @@ -2344,7 +2354,7 @@ public String executor() { } public void testTransportStatsWithException() throws Exception { - MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {}); CountDownLatch receivedLatch = new CountDownLatch(1); CountDownLatch sendResponseLatch = new CountDownLatch(1); Exception ex = new RuntimeException("boom"); @@ -2457,7 +2467,7 @@ public void testTransportProfilesWithPortAndHost() { .put("transport.profiles.some_other_profile.port", "8700-8800") .putArray("transport.profiles.some_other_profile.bind_host", hosts) .putArray("transport.profiles.some_other_profile.publish_host", "_local:ipv4_") - .build(), version0, null, true)) { + .build(), version0, null, true, c -> {})) { serviceC.start(); serviceC.acceptIncomingRequests(); @@ -2612,4 +2622,22 @@ public void testProfilesIncludesDefault() { assertEquals(new HashSet<>(Arrays.asList("default", "test")), profileSettings.stream().map(s -> s.profileName).collect(Collectors .toSet())); } + + public void testChannelCloseWhileConnecting() throws IOException { + final MockTransportService service = buildService("service", version0, clusterSettings, Settings.EMPTY, true, true, this::close); + final TcpTransport underlyingTransport = (TcpTransport) service.getOriginalTransport(); + + final String otherName = "other_service"; + try (TransportService otherService = buildService(otherName, Version.CURRENT, null)) { + final DiscoveryNode node = + new DiscoveryNode(otherName, otherName, otherService.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); + final ConnectTransportException e = + expectThrows(ConnectTransportException.class, () -> underlyingTransport.openConnection(node, null)); + assertThat(e, hasToString(containsString("a channel closed while connecting"))); + } + service.close(); + } + + protected abstract void close(Channel channel); + } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index 29ff4219feecb..ef39b435993f3 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -176,7 +176,9 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, + protected NodeChannels connectToChannels(DiscoveryNode node, + ConnectionProfile profile, + Consumer onChannelOpen, Consumer onChannelClose) throws IOException { final MockChannel[] mockChannels = new MockChannel[1]; final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here @@ -193,6 +195,7 @@ protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile p throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex); } MockChannel channel = new MockChannel(socket, address, "none", onChannelClose); + onChannelOpen.accept(channel); channel.loopRead(executor); mockChannels[0] = channel; success = true; diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java index f846c53abdec6..7b02cfec2d792 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java @@ -56,7 +56,10 @@ public NioClient(Logger logger, OpenChannels openChannels, Supplier onChannelOpen, Consumer closeListener) throws IOException { boolean allowedToConnect = semaphore.tryAcquire(); if (allowedToConnect == false) { @@ -70,6 +73,7 @@ public boolean connectToChannels(DiscoveryNode node, NioSocketChannel[] channels for (int i = 0; i < channels.length; i++) { SocketSelector selector = selectorSupplier.get(); NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address, selector, closeListener); + onChannelOpen.accept(nioSocketChannel); openChannels.clientChannelOpened(nioSocketChannel); connections.add(nioSocketChannel); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 606225fd02ad7..21e165267cf21 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -151,11 +151,12 @@ protected void sendMessage(NioChannel channel, BytesReference reference, ActionL } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) + protected NodeChannels connectToChannels( + DiscoveryNode node, ConnectionProfile profile, Consumer onChannelOpen, Consumer onChannelClose) throws IOException { NioSocketChannel[] channels = new NioSocketChannel[profile.getNumConnections()]; ClientChannelCloseListener closeListener = new ClientChannelCloseListener(onChannelClose); - boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), closeListener); + boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), onChannelOpen, closeListener); if (connected == false) { throw new ElasticsearchException("client is shutdown"); } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java index b32680d9da466..8a3af7633dd43 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java @@ -30,11 +30,15 @@ import org.elasticsearch.test.transport.MockTransportService; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Collections; +import java.util.function.Consumer; + +public class MockTcpTransportTests extends AbstractSimpleTransportTestCase { -public class MockTcpTransportTests extends AbstractSimpleTransportTestCase { @Override - protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { + protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, + Consumer onChannelOpen) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); Transport transport = new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version) { @@ -47,10 +51,25 @@ protected Version executeHandshake(DiscoveryNode node, MockChannel mockChannel, return version.minimumCompatibilityVersion(); } } + + @Override + protected void onChannelOpen(MockChannel mockChannel) { + onChannelOpen.accept(mockChannel); + } }; MockTransportService mockTransportService = MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings); mockTransportService.start(); return mockTransportService; } + + @Override + protected void close(MockTcpTransport.MockChannel mockChannel) { + try { + mockChannel.close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java index 4cae51acc83fa..13d3f22443982 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java @@ -84,7 +84,7 @@ public void testCreateConnections() throws IOException, InterruptedException { when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), c -> {}, listener); assertEquals(channel1, channels[0]); assertEquals(channel2, channels[1]); @@ -99,7 +99,7 @@ public void testWithADifferentConnectTimeout() throws IOException, InterruptedEx when(connectFuture1.awaitConnectionComplete(3, TimeUnit.MILLISECONDS)).thenReturn(true); channels = new NioSocketChannel[1]; - client.connectToChannels(node, channels, TimeValue.timeValueMillis(3), listener); + client.connectToChannels(node, channels, TimeValue.timeValueMillis(3), c -> {}, listener); assertEquals(channel1, channels[0]); } @@ -121,7 +121,7 @@ public void testConnectionTimeout() throws IOException, InterruptedException { when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false); try { - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), c -> {}, listener); fail("Should have thrown ConnectTransportException"); } catch (ConnectTransportException e) { assertTrue(e.getMessage().contains("connect_timeout[5ms]")); @@ -149,7 +149,7 @@ public void testConnectionException() throws IOException, InterruptedException { when(connectFuture2.getException()).thenReturn(ioException); try { - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), c -> {}, listener); fail("Should have thrown ConnectTransportException"); } catch (ConnectTransportException e) { assertTrue(e.getMessage().contains("connect_exception")); @@ -166,7 +166,7 @@ public void testConnectionException() throws IOException, InterruptedException { public void testCloseDoesNotAllowConnections() throws IOException { client.close(); - assertFalse(client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener)); + assertFalse(client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), c -> {}, listener)); for (NioSocketChannel channel : channels) { assertNull(channel); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java index 2ba2e4cc02a85..74b43c9a1f034 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java @@ -41,19 +41,22 @@ import org.elasticsearch.transport.nio.channel.NioChannel; import java.io.IOException; +import java.io.UncheckedIOException; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Collections; +import java.util.function.Consumer; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; -public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase { +public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase { public static MockTransportService nioFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, - ClusterSettings clusterSettings, boolean doHandshake) { + ClusterSettings clusterSettings, boolean doHandshake, + Consumer onChannelOpen) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); NetworkService networkService = new NetworkService(Collections.emptyList()); Transport transport = new NioTransport(settings, threadPool, @@ -70,6 +73,11 @@ protected Version executeHandshake(DiscoveryNode node, NioChannel channel, TimeV } } + @Override + protected void onChannelOpen(NioChannel nioChannel) { + onChannelOpen.accept(nioChannel); + } + @Override protected Version getCurrentVersion() { return version; @@ -87,15 +95,26 @@ protected SocketEventHandler getSocketEventHandler() { } @Override - protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { + protected MockTransportService build( + Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, Consumer onChannelOpen) { settings = Settings.builder().put(settings) .put(TcpTransport.PORT.getKey(), "0") .build(); - MockTransportService transportService = nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake); + MockTransportService transportService = + nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake, onChannelOpen); transportService.start(); return transportService; } + @Override + protected void close(NioChannel nioChannel) { + try { + nioChannel.getCloseFuture().awaitClose(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + public void testConnectException() throws UnknownHostException { try { serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876), @@ -120,7 +139,8 @@ public void testBindUnavailableAddress() { .build(); ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> { - MockTransportService transportService = nioFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true); + MockTransportService transportService = + nioFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true, c -> {}); try { transportService.start(); } finally {