Skip to content

Commit

Permalink
Check for closed connection while opening
Browse files Browse the repository at this point in the history
While opening a connection to a node, a channel can subsequently
close. If this happens, a future callback whose purpose is to close all
other channels and disconnect from the node will fire. However, this
future will not be ready to close all the channels because the
connection will not be exposed to the future callback yet. Since this
callback is run once, we will never try to disconnect from this node
again and we will be left with a closed channel. This commit adds a
check that all channels are open before exposing the channel and throws
a general connection exception. In this case, the usual connection retry
logic will take over.
  • Loading branch information
jasontedor committed Oct 9, 2017
1 parent 3898919 commit 2e658c8
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 35 deletions.
13 changes: 11 additions & 2 deletions core/src/main/java/org/elasticsearch/transport/TcpTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,10 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c
}
}
};
nodeChannels = connectToChannels(node, connectionProfile, onClose);
nodeChannels = connectToChannels(node, connectionProfile, this::onChannelOpen, onClose);
if (!Arrays.stream(nodeChannels.channels).allMatch(this::isOpen)) {
throw new ConnectTransportException(node, "a channel closed while connecting");
}
final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile
final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ?
defaultConnectionProfile.getConnectTimeout() :
Expand Down Expand Up @@ -617,6 +620,10 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c
}
}

protected void onChannelOpen(final Channel channel) {

}

private void disconnectFromNodeCloseAndNotify(DiscoveryNode node, NodeChannels nodeChannels) {
assert nodeChannels != null : "nodeChannels must not be null";
try {
Expand Down Expand Up @@ -1034,7 +1041,9 @@ protected void innerOnFailure(Exception e) {
*/
protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener<Channel> listener);

protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile,
protected abstract NodeChannels connectToChannels(DiscoveryNode node,
ConnectionProfile connectionProfile,
Consumer<Channel> onChannelOpen,
Consumer<Channel> onChannelClose) throws IOException;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.Channel;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -224,9 +225,16 @@ protected void sendMessage(Object o, BytesReference reference, ActionListener li
}

@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile,
protected NodeChannels connectToChannels(DiscoveryNode node,
ConnectionProfile profile,
Consumer onChannelOpen,
Consumer onChannelClose) throws IOException {
return new NodeChannels(node, new Object[profile.getNumConnections()], profile);

final Object[] objects = new Object[profile.getNumConnections()];
for (int i = 0; i < objects.length; i++) {
onChannelOpen.accept(objects[i]);
}
return new NodeChannels(node, objects, profile);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ public long getNumOpenServerConnections() {
}

@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer<Channel> onChannelClose) {
protected NodeChannels connectToChannels(
DiscoveryNode node, ConnectionProfile profile, Consumer<Channel> onChannelOpen, Consumer<Channel> onChannelClose) {
final Channel[] channels = new Channel[profile.getNumConnections()];
final NodeChannels nodeChannels = new NodeChannels(node, channels, profile);
boolean success = false;
Expand Down Expand Up @@ -283,6 +284,7 @@ protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile p
if (!future.isSuccess()) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", future.cause());
}
onChannelOpen.accept(future.channel());
channels[i] = future.channel();
channels[i].closeFuture().addListener(closeListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.elasticsearch.transport.netty4;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
Expand All @@ -44,15 +45,16 @@
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collections;
import java.util.function.Consumer;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString;

public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase {
public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase<Channel> {

public static MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) {
ClusterSettings clusterSettings, boolean doHandshake, Consumer<Channel> onChannelOpen) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
Transport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()),
BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) {
Expand All @@ -67,6 +69,11 @@ protected Version executeHandshake(DiscoveryNode node, Channel channel, TimeValu
}
}

@Override
protected void onChannelOpen(Channel channel) {
onChannelOpen.accept(channel);
}

@Override
protected Version getCurrentVersion() {
return version;
Expand All @@ -79,13 +86,21 @@ protected Version getCurrentVersion() {
}

@Override
protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) {
protected MockTransportService build(
Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, Consumer<Channel> onChannelOpen) {
settings = Settings.builder().put(settings).put(TcpTransport.PORT.getKey(), "0").build();
MockTransportService transportService = nettyFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake);
MockTransportService transportService =
nettyFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake, onChannelOpen);
transportService.start();
return transportService;
}

@Override
protected void close(Channel channel) {
final ChannelFuture future = channel.close();
future.awaitUninterruptibly();
}

public void testConnectException() throws UnknownHostException {
try {
serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876),
Expand All @@ -108,7 +123,8 @@ public void testBindUnavailableAddress() {
.build();
ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> {
MockTransportService transportService = nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true);
MockTransportService transportService =
nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true, c -> {});
try {
transportService.start();
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,20 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;

public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
public abstract class AbstractSimpleTransportTestCase<Channel> extends ESTestCase {

protected ThreadPool threadPool;
// we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions
Expand All @@ -105,7 +108,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
protected volatile DiscoveryNode nodeB;
protected volatile MockTransportService serviceB;

protected abstract MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake);
protected abstract MockTransportService build(
Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake, Consumer<Channel> onChannelOpen);

@Override
@Before
Expand Down Expand Up @@ -146,6 +150,12 @@ public void onNodeDisconnected(DiscoveryNode node) {

private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings,
Settings settings, boolean acceptRequests, boolean doHandshake) {
return buildService(name, version, clusterSettings, settings, acceptRequests, doHandshake, c -> {});
}

private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings,
Settings settings, boolean acceptRequests, boolean doHandshake,
Consumer<Channel> onChannelOpen) {
MockTransportService service = build(
Settings.builder()
.put(settings)
Expand All @@ -154,7 +164,7 @@ private MockTransportService buildService(final String name, final Version versi
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.build(),
version,
clusterSettings, doHandshake);
clusterSettings, doHandshake, onChannelOpen);
if (acceptRequests) {
service.acceptIncomingRequests();
}
Expand Down Expand Up @@ -1692,7 +1702,7 @@ public void testSendRandomRequests() throws InterruptedException {
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.build(),
version0,
null, true);
null, true, c -> {});
DiscoveryNode nodeC = serviceC.getLocalNode();
serviceC.acceptIncomingRequests();

Expand Down Expand Up @@ -2125,7 +2135,7 @@ public String executor() {
public void testHandlerIsInvokedOnConnectionClose() throws IOException, InterruptedException {
List<String> executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet());
CollectionUtil.timSort(executors); // makes sure it's reproducible
TransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true);
TransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {});
serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME,
(request, channel) -> {
// do nothing
Expand Down Expand Up @@ -2183,7 +2193,7 @@ public String executor() {
}

public void testConcurrentDisconnectOnNonPublishedConnection() throws IOException, InterruptedException {
MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true);
MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {});
CountDownLatch receivedLatch = new CountDownLatch(1);
CountDownLatch sendResponseLatch = new CountDownLatch(1);
serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME,
Expand Down Expand Up @@ -2251,7 +2261,7 @@ public String executor() {
}

public void testTransportStats() throws Exception {
MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true);
MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {});
CountDownLatch receivedLatch = new CountDownLatch(1);
CountDownLatch sendResponseLatch = new CountDownLatch(1);
serviceB.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME,
Expand Down Expand Up @@ -2344,7 +2354,7 @@ public String executor() {
}

public void testTransportStatsWithException() throws Exception {
MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true);
MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true, c -> {});
CountDownLatch receivedLatch = new CountDownLatch(1);
CountDownLatch sendResponseLatch = new CountDownLatch(1);
Exception ex = new RuntimeException("boom");
Expand Down Expand Up @@ -2457,7 +2467,7 @@ public void testTransportProfilesWithPortAndHost() {
.put("transport.profiles.some_other_profile.port", "8700-8800")
.putArray("transport.profiles.some_other_profile.bind_host", hosts)
.putArray("transport.profiles.some_other_profile.publish_host", "_local:ipv4_")
.build(), version0, null, true)) {
.build(), version0, null, true, c -> {})) {

serviceC.start();
serviceC.acceptIncomingRequests();
Expand Down Expand Up @@ -2612,4 +2622,22 @@ public void testProfilesIncludesDefault() {
assertEquals(new HashSet<>(Arrays.asList("default", "test")), profileSettings.stream().map(s -> s.profileName).collect(Collectors
.toSet()));
}

public void testChannelCloseWhileConnecting() throws IOException {
final MockTransportService service = buildService("service", version0, clusterSettings, Settings.EMPTY, true, true, this::close);
final TcpTransport underlyingTransport = (TcpTransport) service.getOriginalTransport();

final String otherName = "other_service";
try (TransportService otherService = buildService(otherName, Version.CURRENT, null)) {
final DiscoveryNode node =
new DiscoveryNode(otherName, otherName, otherService.boundAddress().publishAddress(), emptyMap(), emptySet(), version0);
final ConnectTransportException e =
expectThrows(ConnectTransportException.class, () -> underlyingTransport.openConnection(node, null));
assertThat(e, hasToString(containsString("a channel closed while connecting")));
}
service.close();
}

protected abstract void close(Channel channel);

}
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx
}

@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile,
protected NodeChannels connectToChannels(DiscoveryNode node,
ConnectionProfile profile,
Consumer<MockChannel> onChannelOpen,
Consumer<MockChannel> onChannelClose) throws IOException {
final MockChannel[] mockChannels = new MockChannel[1];
final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here
Expand All @@ -193,6 +195,7 @@ protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile p
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex);
}
MockChannel channel = new MockChannel(socket, address, "none", onChannelClose);
onChannelOpen.accept(channel);
channel.loopRead(executor);
mockChannels[0] = channel;
success = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ public NioClient(Logger logger, OpenChannels openChannels, Supplier<SocketSelect
this.channelFactory = channelFactory;
}

public boolean connectToChannels(DiscoveryNode node, NioSocketChannel[] channels, TimeValue connectTimeout,
public boolean connectToChannels(DiscoveryNode node,
NioSocketChannel[] channels,
TimeValue connectTimeout,
Consumer<NioChannel> onChannelOpen,
Consumer<NioChannel> closeListener) throws IOException {
boolean allowedToConnect = semaphore.tryAcquire();
if (allowedToConnect == false) {
Expand All @@ -70,6 +73,7 @@ public boolean connectToChannels(DiscoveryNode node, NioSocketChannel[] channels
for (int i = 0; i < channels.length; i++) {
SocketSelector selector = selectorSupplier.get();
NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address, selector, closeListener);
onChannelOpen.accept(nioSocketChannel);
openChannels.clientChannelOpened(nioSocketChannel);
connections.add(nioSocketChannel);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,12 @@ protected void sendMessage(NioChannel channel, BytesReference reference, ActionL
}

@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer<NioChannel> onChannelClose)
protected NodeChannels connectToChannels(
DiscoveryNode node, ConnectionProfile profile, Consumer<NioChannel> onChannelOpen, Consumer<NioChannel> onChannelClose)
throws IOException {
NioSocketChannel[] channels = new NioSocketChannel[profile.getNumConnections()];
ClientChannelCloseListener closeListener = new ClientChannelCloseListener(onChannelClose);
boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), closeListener);
boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), onChannelOpen, closeListener);
if (connected == false) {
throw new ElasticsearchException("client is shutdown");
}
Expand Down
Loading

0 comments on commit 2e658c8

Please sign in to comment.