diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 6bf731f2936d9..c217aa98de370 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -588,7 +588,10 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c } } }; - nodeChannels = connectToChannels(node, connectionProfile, onClose); + nodeChannels = connectToChannels(node, connectionProfile, this::onChannelOpen, onClose); + if (!Arrays.stream(nodeChannels.channels).allMatch(this::isOpen)) { + throw new ConnectTransportException(node, "a channel closed while connecting"); + } final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ? defaultConnectionProfile.getConnectTimeout() : @@ -617,6 +620,10 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c } } + protected void onChannelOpen(final Channel channel) { + + } + private void disconnectFromNodeCloseAndNotify(DiscoveryNode node, NodeChannels nodeChannels) { assert nodeChannels != null : "nodeChannels must not be null"; try { @@ -1034,7 +1041,9 @@ protected void innerOnFailure(Exception e) { */ protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener listener); - protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile, + protected abstract NodeChannels connectToChannels(DiscoveryNode node, + ConnectionProfile connectionProfile, + Consumer onChannelOpen, Consumer onChannelClose) throws IOException; /** diff --git a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java index 55457cc8ae431..fa1f2dddfb56d 100644 --- a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java @@ -37,6 +37,7 @@ import java.io.IOException; import java.net.InetSocketAddress; +import java.nio.channels.Channel; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -224,9 +225,16 @@ protected void sendMessage(Object o, BytesReference reference, ActionListener li } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, + protected NodeChannels connectToChannels(DiscoveryNode node, + ConnectionProfile profile, + Consumer onChannelOpen, Consumer onChannelClose) throws IOException { - return new NodeChannels(node, new Object[profile.getNumConnections()], profile); + + final Object[] objects = new Object[profile.getNumConnections()]; + for (int i = 0; i < objects.length; i++) { + onChannelOpen.accept(objects[i]); + } + return new NodeChannels(node, objects, profile); } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index 84c86bd2d770a..a7d8dc4679b32 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -252,7 +252,8 @@ public long getNumOpenServerConnections() { } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) { + protected NodeChannels connectToChannels( + DiscoveryNode node, ConnectionProfile profile, Consumer onChannelOpen, Consumer onChannelClose) { final Channel[] channels = new Channel[profile.getNumConnections()]; final NodeChannels nodeChannels = new NodeChannels(node, channels, profile); boolean success = false; @@ -283,6 +284,7 @@ protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile p if (!future.isSuccess()) { throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", future.cause()); } + onChannelOpen.accept(future.channel()); channels[i] = future.channel(); channels[i].closeFuture().addListener(closeListener); } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java index 92c21f942c292..d094cdb987be4 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.transport.netty4; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -44,15 +45,16 @@ import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Collections; +import java.util.function.Consumer; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; import static org.hamcrest.Matchers.containsString; -public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase { +public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase { public static MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, - ClusterSettings clusterSettings, boolean doHandshake) { + ClusterSettings clusterSettings, boolean doHandshake, Consumer onChannelOpen) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); Transport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @@ -67,6 +69,11 @@ protected Version executeHandshake(DiscoveryNode node, Channel channel, TimeValu } } + @Override + protected void onChannelOpen(Channel channel) { + onChannelOpen.accept(channel); + } + @Override protected Version getCurrentVersion() { return version; @@ -79,13 +86,21 @@ protected Version getCurrentVersion() { } @Override - protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { + protected MockTransportService build( + Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, Consumer onChannelOpen) { settings = Settings.builder().put(settings).put(TcpTransport.PORT.getKey(), "0").build(); - MockTransportService transportService = nettyFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake); + MockTransportService transportService = + nettyFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake, onChannelOpen); transportService.start(); return transportService; } + @Override + protected void close(Channel channel) { + final ChannelFuture future = channel.close(); + future.awaitUninterruptibly(); + } + public void testConnectException() throws UnknownHostException { try { serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876), @@ -108,7 +123,8 @@ public void testBindUnavailableAddress() { .build(); ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> { - MockTransportService transportService = nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true); + MockTransportService transportService = + nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true, c -> {}); try { transportService.start(); } finally { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index da43f116d4245..d170ba64f79f3 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -79,17 +79,20 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.stream.Collectors; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.startsWith; -public abstract class AbstractSimpleTransportTestCase extends ESTestCase { +public abstract class AbstractSimpleTransportTestCase extends ESTestCase { protected ThreadPool threadPool; // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions @@ -105,7 +108,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { protected volatile DiscoveryNode nodeB; protected volatile MockTransportService serviceB; - protected abstract MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake); + protected abstract MockTransportService build( + Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, Consumer onChannelOpen); @Override @Before @@ -146,6 +150,12 @@ public void onNodeDisconnected(DiscoveryNode node) { private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings, Settings settings, boolean acceptRequests, boolean doHandshake) { + return buildService(name, version, clusterSettings, settings, acceptRequests, doHandshake, c -> {}); + } + + private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings, + Settings settings, boolean acceptRequests, boolean doHandshake, + Consumer onChannelOpen) { MockTransportService service = build( Settings.builder() .put(settings) @@ -154,7 +164,7 @@ private MockTransportService buildService(final String name, final Version versi .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") .build(), version, - clusterSettings, doHandshake); + clusterSettings, doHandshake, onChannelOpen); if (acceptRequests) { service.acceptIncomingRequests(); } @@ -1692,7 +1702,7 @@ public void testSendRandomRequests() throws InterruptedException { .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") .build(), version0, - null, true); + null, true, c -> {}); DiscoveryNode nodeC = serviceC.getLocalNode(); serviceC.acceptIncomingRequests(); @@ -2125,7 +2135,7 @@ public String executor() { public void testHandlerIsInvokedOnConnectionClose() throws IOException, InterruptedException { List executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet()); CollectionUtil.timSort(executors); // makes sure it's reproducible - TransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + TransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {}); serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, (request, channel) -> { // do nothing @@ -2183,7 +2193,7 @@ public String executor() { } public void testConcurrentDisconnectOnNonPublishedConnection() throws IOException, InterruptedException { - MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {}); CountDownLatch receivedLatch = new CountDownLatch(1); CountDownLatch sendResponseLatch = new CountDownLatch(1); serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, @@ -2251,7 +2261,7 @@ public String executor() { } public void testTransportStats() throws Exception { - MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {}); CountDownLatch receivedLatch = new CountDownLatch(1); CountDownLatch sendResponseLatch = new CountDownLatch(1); serviceB.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, @@ -2344,7 +2354,7 @@ public String executor() { } public void testTransportStatsWithException() throws Exception { - MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true); + MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {}); CountDownLatch receivedLatch = new CountDownLatch(1); CountDownLatch sendResponseLatch = new CountDownLatch(1); Exception ex = new RuntimeException("boom"); @@ -2457,7 +2467,7 @@ public void testTransportProfilesWithPortAndHost() { .put("transport.profiles.some_other_profile.port", "8700-8800") .putArray("transport.profiles.some_other_profile.bind_host", hosts) .putArray("transport.profiles.some_other_profile.publish_host", "_local:ipv4_") - .build(), version0, null, true)) { + .build(), version0, null, true, c -> {})) { serviceC.start(); serviceC.acceptIncomingRequests(); @@ -2612,4 +2622,22 @@ public void testProfilesIncludesDefault() { assertEquals(new HashSet<>(Arrays.asList("default", "test")), profileSettings.stream().map(s -> s.profileName).collect(Collectors .toSet())); } + + public void testChannelCloseWhileConnecting() throws IOException { + final MockTransportService service = buildService("service", version0, clusterSettings, Settings.EMPTY, true, true, this::close); + final TcpTransport underlyingTransport = (TcpTransport) service.getOriginalTransport(); + + final String otherName = "other_service"; + try (TransportService otherService = buildService(otherName, Version.CURRENT, null)) { + final DiscoveryNode node = + new DiscoveryNode(otherName, otherName, otherService.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); + final ConnectTransportException e = + expectThrows(ConnectTransportException.class, () -> underlyingTransport.openConnection(node, null)); + assertThat(e, hasToString(containsString("a channel closed while connecting"))); + } + service.close(); + } + + protected abstract void close(Channel channel); + } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index 29ff4219feecb..ef39b435993f3 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -176,7 +176,9 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, + protected NodeChannels connectToChannels(DiscoveryNode node, + ConnectionProfile profile, + Consumer onChannelOpen, Consumer onChannelClose) throws IOException { final MockChannel[] mockChannels = new MockChannel[1]; final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here @@ -193,6 +195,7 @@ protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile p throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex); } MockChannel channel = new MockChannel(socket, address, "none", onChannelClose); + onChannelOpen.accept(channel); channel.loopRead(executor); mockChannels[0] = channel; success = true; diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java index f846c53abdec6..7b02cfec2d792 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java @@ -56,7 +56,10 @@ public NioClient(Logger logger, OpenChannels openChannels, Supplier onChannelOpen, Consumer closeListener) throws IOException { boolean allowedToConnect = semaphore.tryAcquire(); if (allowedToConnect == false) { @@ -70,6 +73,7 @@ public boolean connectToChannels(DiscoveryNode node, NioSocketChannel[] channels for (int i = 0; i < channels.length; i++) { SocketSelector selector = selectorSupplier.get(); NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address, selector, closeListener); + onChannelOpen.accept(nioSocketChannel); openChannels.clientChannelOpened(nioSocketChannel); connections.add(nioSocketChannel); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 606225fd02ad7..21e165267cf21 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -151,11 +151,12 @@ protected void sendMessage(NioChannel channel, BytesReference reference, ActionL } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) + protected NodeChannels connectToChannels( + DiscoveryNode node, ConnectionProfile profile, Consumer onChannelOpen, Consumer onChannelClose) throws IOException { NioSocketChannel[] channels = new NioSocketChannel[profile.getNumConnections()]; ClientChannelCloseListener closeListener = new ClientChannelCloseListener(onChannelClose); - boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), closeListener); + boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), onChannelOpen, closeListener); if (connected == false) { throw new ElasticsearchException("client is shutdown"); } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java index b32680d9da466..8a3af7633dd43 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java @@ -30,11 +30,15 @@ import org.elasticsearch.test.transport.MockTransportService; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Collections; +import java.util.function.Consumer; + +public class MockTcpTransportTests extends AbstractSimpleTransportTestCase { -public class MockTcpTransportTests extends AbstractSimpleTransportTestCase { @Override - protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { + protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, + Consumer onChannelOpen) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); Transport transport = new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version) { @@ -47,10 +51,25 @@ protected Version executeHandshake(DiscoveryNode node, MockChannel mockChannel, return version.minimumCompatibilityVersion(); } } + + @Override + protected void onChannelOpen(MockChannel mockChannel) { + onChannelOpen.accept(mockChannel); + } }; MockTransportService mockTransportService = MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings); mockTransportService.start(); return mockTransportService; } + + @Override + protected void close(MockTcpTransport.MockChannel mockChannel) { + try { + mockChannel.close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java index 4cae51acc83fa..13d3f22443982 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java @@ -84,7 +84,7 @@ public void testCreateConnections() throws IOException, InterruptedException { when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), c -> {}, listener); assertEquals(channel1, channels[0]); assertEquals(channel2, channels[1]); @@ -99,7 +99,7 @@ public void testWithADifferentConnectTimeout() throws IOException, InterruptedEx when(connectFuture1.awaitConnectionComplete(3, TimeUnit.MILLISECONDS)).thenReturn(true); channels = new NioSocketChannel[1]; - client.connectToChannels(node, channels, TimeValue.timeValueMillis(3), listener); + client.connectToChannels(node, channels, TimeValue.timeValueMillis(3), c -> {}, listener); assertEquals(channel1, channels[0]); } @@ -121,7 +121,7 @@ public void testConnectionTimeout() throws IOException, InterruptedException { when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false); try { - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), c -> {}, listener); fail("Should have thrown ConnectTransportException"); } catch (ConnectTransportException e) { assertTrue(e.getMessage().contains("connect_timeout[5ms]")); @@ -149,7 +149,7 @@ public void testConnectionException() throws IOException, InterruptedException { when(connectFuture2.getException()).thenReturn(ioException); try { - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), c -> {}, listener); fail("Should have thrown ConnectTransportException"); } catch (ConnectTransportException e) { assertTrue(e.getMessage().contains("connect_exception")); @@ -166,7 +166,7 @@ public void testConnectionException() throws IOException, InterruptedException { public void testCloseDoesNotAllowConnections() throws IOException { client.close(); - assertFalse(client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener)); + assertFalse(client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), c -> {}, listener)); for (NioSocketChannel channel : channels) { assertNull(channel); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java index 2ba2e4cc02a85..74b43c9a1f034 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java @@ -41,19 +41,22 @@ import org.elasticsearch.transport.nio.channel.NioChannel; import java.io.IOException; +import java.io.UncheckedIOException; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Collections; +import java.util.function.Consumer; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; -public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase { +public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase { public static MockTransportService nioFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, - ClusterSettings clusterSettings, boolean doHandshake) { + ClusterSettings clusterSettings, boolean doHandshake, + Consumer onChannelOpen) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); NetworkService networkService = new NetworkService(Collections.emptyList()); Transport transport = new NioTransport(settings, threadPool, @@ -70,6 +73,11 @@ protected Version executeHandshake(DiscoveryNode node, NioChannel channel, TimeV } } + @Override + protected void onChannelOpen(NioChannel nioChannel) { + onChannelOpen.accept(nioChannel); + } + @Override protected Version getCurrentVersion() { return version; @@ -87,15 +95,26 @@ protected SocketEventHandler getSocketEventHandler() { } @Override - protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { + protected MockTransportService build( + Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, Consumer onChannelOpen) { settings = Settings.builder().put(settings) .put(TcpTransport.PORT.getKey(), "0") .build(); - MockTransportService transportService = nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake); + MockTransportService transportService = + nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake, onChannelOpen); transportService.start(); return transportService; } + @Override + protected void close(NioChannel nioChannel) { + try { + nioChannel.getCloseFuture().awaitClose(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + public void testConnectException() throws UnknownHostException { try { serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876), @@ -120,7 +139,8 @@ public void testBindUnavailableAddress() { .build(); ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> { - MockTransportService transportService = nioFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true); + MockTransportService transportService = + nioFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true, c -> {}); try { transportService.start(); } finally {