Skip to content

Commit

Permalink
Fix race condition in RemoteClusterConnection node supplier (#25432)
Browse files Browse the repository at this point in the history
This commit fixes a race condition in the node supplier used by the RemoteClusterConnection. The
node supplier stores an iterator over a set backed by a ConcurrentHashMap, but the get operation
of the supplier uses multiple methods of the iterator and is suceptible to a race between the
calls to hasNext() and next(). The test in this commit fails under the old implementation with a
NoSuchElementException. This commit adds a wrapper object over a set and a iterator, with all methods
being synchronized to avoid races. Modifications to the set result in the iterator being set to null
and the next retrieval creates a new iterator.
  • Loading branch information
jaymode committed Jun 28, 2017
1 parent 4ca1567 commit 415b81e
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.component.AbstractComponent;
Expand All @@ -56,7 +55,6 @@
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
Expand All @@ -81,8 +79,7 @@ final class RemoteClusterConnection extends AbstractComponent implements Transpo

private final TransportService transportService;
private final ConnectionProfile remoteProfile;
private final Set<DiscoveryNode> connectedNodes = Collections.newSetFromMap(new ConcurrentHashMap<>());
private final Supplier<DiscoveryNode> nodeSupplier;
private final ConnectedNodes connectedNodes;
private final String clusterAlias;
private final int maxNumRemoteConnections;
private final Predicate<DiscoveryNode> nodePredicate;
Expand Down Expand Up @@ -114,19 +111,7 @@ final class RemoteClusterConnection extends AbstractComponent implements Transpo
TransportRequestOptions.Type.STATE,
TransportRequestOptions.Type.RECOVERY);
remoteProfile = builder.build();
nodeSupplier = new Supplier<DiscoveryNode>() {
private volatile Iterator<DiscoveryNode> current;
@Override
public DiscoveryNode get() {
if (current == null || current.hasNext() == false) {
current = connectedNodes.iterator();
if (current.hasNext() == false) {
throw new IllegalStateException("No node available for cluster: " + clusterAlias + " nodes: " + connectedNodes);
}
}
return current.next();
}
};
connectedNodes = new ConnectedNodes(clusterAlias);
this.seedNodes = Collections.unmodifiableList(seedNodes);
this.connectHandler = new ConnectHandler();
transportService.addConnectionListener(this);
Expand Down Expand Up @@ -154,7 +139,7 @@ public void onNodeDisconnected(DiscoveryNode node) {
*/
public void fetchSearchShards(ClusterSearchShardsRequest searchRequest,
ActionListener<ClusterSearchShardsResponse> listener) {
if (connectedNodes.isEmpty()) {
if (connectedNodes.size() == 0) {
// just in case if we are not connected for some reason we try to connect and if we fail we have to notify the listener
// this will cause some back pressure on the search end and eventually will cause rejections but that's fine
// we can't proceed with a search on a cluster level.
Expand All @@ -171,7 +156,7 @@ public void fetchSearchShards(ClusterSearchShardsRequest searchRequest,
* will invoke the listener immediately.
*/
public void ensureConnected(ActionListener<Void> voidActionListener) {
if (connectedNodes.isEmpty()) {
if (connectedNodes.size() == 0) {
connectHandler.connect(voidActionListener);
} else {
voidActionListener.onResponse(null);
Expand All @@ -180,7 +165,7 @@ public void ensureConnected(ActionListener<Void> voidActionListener) {

private void fetchShardsInternal(ClusterSearchShardsRequest searchShardsRequest,
final ActionListener<ClusterSearchShardsResponse> listener) {
final DiscoveryNode node = nodeSupplier.get();
final DiscoveryNode node = connectedNodes.get();
transportService.sendRequest(node, ClusterSearchShardsAction.NAME, searchShardsRequest,
new TransportResponseHandler<ClusterSearchShardsResponse>() {

Expand Down Expand Up @@ -211,7 +196,7 @@ public String executor() {
* given node.
*/
Transport.Connection getConnection(DiscoveryNode remoteClusterNode) {
DiscoveryNode discoveryNode = nodeSupplier.get();
DiscoveryNode discoveryNode = connectedNodes.get();
Transport.Connection connection = transportService.getConnection(discoveryNode);
return new Transport.Connection() {
@Override
Expand All @@ -234,12 +219,11 @@ public void close() throws IOException {
}

Transport.Connection getConnection() {
DiscoveryNode discoveryNode = nodeSupplier.get();
DiscoveryNode discoveryNode = connectedNodes.get();
return transportService.getConnection(discoveryNode);
}


@Override
@Override
public void close() throws IOException {
connectHandler.close();
}
Expand Down Expand Up @@ -534,12 +518,19 @@ boolean isNodeConnected(final DiscoveryNode node) {
return connectedNodes.contains(node);
}

DiscoveryNode getConnectedNode() {
return connectedNodes.get();
}

void addConnectedNode(DiscoveryNode node) {
connectedNodes.add(node);
}

/**
* Fetches connection info for this connection
*/
public void getConnectionInfo(ActionListener<RemoteConnectionInfo> listener) {
final Optional<DiscoveryNode> anyNode = connectedNodes.stream().findAny();
final Optional<DiscoveryNode> anyNode = connectedNodes.getAny();
if (anyNode.isPresent() == false) {
// not connected we return immediately
RemoteConnectionInfo remoteConnectionStats = new RemoteConnectionInfo(clusterAlias,
Expand Down Expand Up @@ -601,4 +592,68 @@ public String executor() {
int getNumNodesConnected() {
return connectedNodes.size();
}

private static class ConnectedNodes implements Supplier<DiscoveryNode> {

private final Set<DiscoveryNode> nodeSet = new HashSet<>();
private final String clusterAlias;

private Iterator<DiscoveryNode> currentIterator = null;

private ConnectedNodes(String clusterAlias) {
this.clusterAlias = clusterAlias;
}

@Override
public synchronized DiscoveryNode get() {
ensureIteratorAvailable();
if (currentIterator.hasNext()) {
return currentIterator.next();
} else {
throw new IllegalStateException("No node available for cluster: " + clusterAlias);
}
}

synchronized boolean remove(DiscoveryNode node) {
final boolean setRemoval = nodeSet.remove(node);
if (setRemoval) {
currentIterator = null;
}
return setRemoval;
}

synchronized boolean add(DiscoveryNode node) {
final boolean added = nodeSet.add(node);
if (added) {
currentIterator = null;
}
return added;
}

synchronized int size() {
return nodeSet.size();
}

synchronized boolean contains(DiscoveryNode node) {
return nodeSet.contains(node);
}

synchronized Optional<DiscoveryNode> getAny() {
ensureIteratorAvailable();
if (currentIterator.hasNext()) {
return Optional.of(currentIterator.next());
} else {
return Optional.empty();
}
}

private synchronized void ensureIteratorAvailable() {
if (currentIterator == null) {
currentIterator = nodeSet.iterator();
} else if (currentIterator.hasNext() == false && nodeSet.isEmpty() == false) {
// iterator rollover
currentIterator = nodeSet.iterator();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.elasticsearch.transport;

import org.apache.lucene.store.AlreadyClosedException;
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Build;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -56,11 +57,6 @@
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterConnection;
import org.elasticsearch.transport.RemoteConnectionInfo;
import org.elasticsearch.transport.RemoteTransportException;
import org.elasticsearch.transport.TransportConnectionListener;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.net.InetAddress;
Expand All @@ -78,6 +74,7 @@
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.Collections.emptyMap;
Expand Down Expand Up @@ -787,4 +784,89 @@ public void onFailure(Exception e) {
}
}
}

public void testConnectedNodesConcurrentAccess() throws IOException, InterruptedException {
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
List<MockTransportService> discoverableTransports = new CopyOnWriteArrayList<>();
try {
final int numDiscoverableNodes = randomIntBetween(5, 20);
List<DiscoveryNode> discoverableNodes = new ArrayList<>(numDiscoverableNodes);
for (int i = 0; i < numDiscoverableNodes; i++ ) {
MockTransportService transportService = startTransport("discoverable_node" + i, knownNodes, Version.CURRENT);
discoverableNodes.add(transportService.getLocalDiscoNode());
discoverableTransports.add(transportService);
}

List<DiscoveryNode> seedNodes = randomSubsetOf(discoverableNodes);
Collections.shuffle(seedNodes, random());

try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) {
service.start();
service.acceptIncomingRequests();
try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
seedNodes, service, Integer.MAX_VALUE, n -> true)) {
final int numGetThreads = randomIntBetween(4, 10);
final Thread[] getThreads = new Thread[numGetThreads];
final int numModifyingThreads = randomIntBetween(4, 10);
final Thread[] modifyingThreads = new Thread[numModifyingThreads];
CyclicBarrier barrier = new CyclicBarrier(numGetThreads + numModifyingThreads);
for (int i = 0; i < getThreads.length; i++) {
final int numGetCalls = randomIntBetween(1000, 10000);
getThreads[i] = new Thread(() -> {
try {
barrier.await();
for (int j = 0; j < numGetCalls; j++) {
try {
DiscoveryNode node = connection.getConnectedNode();
assertNotNull(node);
} catch (IllegalStateException e) {
if (e.getMessage().startsWith("No node available for cluster:") == false) {
throw e;
}
}
}
} catch (Exception ex) {
throw new AssertionError(ex);
}
});
getThreads[i].start();
}

final AtomicInteger counter = new AtomicInteger();
for (int i = 0; i < modifyingThreads.length; i++) {
final int numDisconnects = randomIntBetween(5, 10);
modifyingThreads[i] = new Thread(() -> {
try {
barrier.await();
for (int j = 0; j < numDisconnects; j++) {
if (randomBoolean()) {
MockTransportService transportService =
startTransport("discoverable_node_added" + counter.incrementAndGet(), knownNodes,
Version.CURRENT);
discoverableTransports.add(transportService);
connection.addConnectedNode(transportService.getLocalDiscoNode());
} else {
DiscoveryNode node = randomFrom(discoverableNodes);
connection.onNodeDisconnected(node);
}
}
} catch (Exception ex) {
throw new AssertionError(ex);
}
});
modifyingThreads[i].start();
}

for (Thread thread : getThreads) {
thread.join();
}
for (Thread thread : modifyingThreads) {
thread.join();
}
}
}
} finally {
IOUtils.closeWhileHandlingException(discoverableTransports);
}
}
}

0 comments on commit 415b81e

Please sign in to comment.