diff --git a/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java b/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java index 2b16c26931b86..a3c7613f821ba 100644 --- a/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java +++ b/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java @@ -33,7 +33,6 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateAction; import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest; import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.component.AbstractComponent; @@ -56,7 +55,6 @@ import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.Semaphore; @@ -81,8 +79,7 @@ final class RemoteClusterConnection extends AbstractComponent implements Transpo private final TransportService transportService; private final ConnectionProfile remoteProfile; - private final Set connectedNodes = Collections.newSetFromMap(new ConcurrentHashMap<>()); - private final Supplier nodeSupplier; + private final ConnectedNodes connectedNodes; private final String clusterAlias; private final int maxNumRemoteConnections; private final Predicate nodePredicate; @@ -114,19 +111,7 @@ final class RemoteClusterConnection extends AbstractComponent implements Transpo TransportRequestOptions.Type.STATE, TransportRequestOptions.Type.RECOVERY); remoteProfile = builder.build(); - nodeSupplier = new Supplier() { - private volatile Iterator current; - @Override - public DiscoveryNode get() { - if (current == null || current.hasNext() == false) { - current = connectedNodes.iterator(); - if (current.hasNext() == false) { - throw new IllegalStateException("No node available for cluster: " + clusterAlias + " nodes: " + connectedNodes); - } - } - return current.next(); - } - }; + connectedNodes = new ConnectedNodes(clusterAlias); this.seedNodes = Collections.unmodifiableList(seedNodes); this.connectHandler = new ConnectHandler(); transportService.addConnectionListener(this); @@ -154,7 +139,7 @@ public void onNodeDisconnected(DiscoveryNode node) { */ public void fetchSearchShards(ClusterSearchShardsRequest searchRequest, ActionListener listener) { - if (connectedNodes.isEmpty()) { + if (connectedNodes.size() == 0) { // just in case if we are not connected for some reason we try to connect and if we fail we have to notify the listener // this will cause some back pressure on the search end and eventually will cause rejections but that's fine // we can't proceed with a search on a cluster level. @@ -171,7 +156,7 @@ public void fetchSearchShards(ClusterSearchShardsRequest searchRequest, * will invoke the listener immediately. */ public void ensureConnected(ActionListener voidActionListener) { - if (connectedNodes.isEmpty()) { + if (connectedNodes.size() == 0) { connectHandler.connect(voidActionListener); } else { voidActionListener.onResponse(null); @@ -180,7 +165,7 @@ public void ensureConnected(ActionListener voidActionListener) { private void fetchShardsInternal(ClusterSearchShardsRequest searchShardsRequest, final ActionListener listener) { - final DiscoveryNode node = nodeSupplier.get(); + final DiscoveryNode node = connectedNodes.get(); transportService.sendRequest(node, ClusterSearchShardsAction.NAME, searchShardsRequest, new TransportResponseHandler() { @@ -211,7 +196,7 @@ public String executor() { * given node. */ Transport.Connection getConnection(DiscoveryNode remoteClusterNode) { - DiscoveryNode discoveryNode = nodeSupplier.get(); + DiscoveryNode discoveryNode = connectedNodes.get(); Transport.Connection connection = transportService.getConnection(discoveryNode); return new Transport.Connection() { @Override @@ -234,12 +219,11 @@ public void close() throws IOException { } Transport.Connection getConnection() { - DiscoveryNode discoveryNode = nodeSupplier.get(); + DiscoveryNode discoveryNode = connectedNodes.get(); return transportService.getConnection(discoveryNode); } - - @Override + @Override public void close() throws IOException { connectHandler.close(); } @@ -534,12 +518,19 @@ boolean isNodeConnected(final DiscoveryNode node) { return connectedNodes.contains(node); } + DiscoveryNode getConnectedNode() { + return connectedNodes.get(); + } + + void addConnectedNode(DiscoveryNode node) { + connectedNodes.add(node); + } /** * Fetches connection info for this connection */ public void getConnectionInfo(ActionListener listener) { - final Optional anyNode = connectedNodes.stream().findAny(); + final Optional anyNode = connectedNodes.getAny(); if (anyNode.isPresent() == false) { // not connected we return immediately RemoteConnectionInfo remoteConnectionStats = new RemoteConnectionInfo(clusterAlias, @@ -601,4 +592,68 @@ public String executor() { int getNumNodesConnected() { return connectedNodes.size(); } + + private static class ConnectedNodes implements Supplier { + + private final Set nodeSet = new HashSet<>(); + private final String clusterAlias; + + private Iterator currentIterator = null; + + private ConnectedNodes(String clusterAlias) { + this.clusterAlias = clusterAlias; + } + + @Override + public synchronized DiscoveryNode get() { + ensureIteratorAvailable(); + if (currentIterator.hasNext()) { + return currentIterator.next(); + } else { + throw new IllegalStateException("No node available for cluster: " + clusterAlias); + } + } + + synchronized boolean remove(DiscoveryNode node) { + final boolean setRemoval = nodeSet.remove(node); + if (setRemoval) { + currentIterator = null; + } + return setRemoval; + } + + synchronized boolean add(DiscoveryNode node) { + final boolean added = nodeSet.add(node); + if (added) { + currentIterator = null; + } + return added; + } + + synchronized int size() { + return nodeSet.size(); + } + + synchronized boolean contains(DiscoveryNode node) { + return nodeSet.contains(node); + } + + synchronized Optional getAny() { + ensureIteratorAvailable(); + if (currentIterator.hasNext()) { + return Optional.of(currentIterator.next()); + } else { + return Optional.empty(); + } + } + + private synchronized void ensureIteratorAvailable() { + if (currentIterator == null) { + currentIterator = nodeSet.iterator(); + } else if (currentIterator.hasNext() == false && nodeSet.isEmpty() == false) { + // iterator rollover + currentIterator = nodeSet.iterator(); + } + } + } } diff --git a/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java b/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java index 183d9c0a4e5fb..e1ca2d98e527e 100644 --- a/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java +++ b/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport; import org.apache.lucene.store.AlreadyClosedException; +import org.apache.lucene.util.IOUtils; import org.elasticsearch.Build; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; @@ -56,11 +57,6 @@ import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.RemoteClusterConnection; -import org.elasticsearch.transport.RemoteConnectionInfo; -import org.elasticsearch.transport.RemoteTransportException; -import org.elasticsearch.transport.TransportConnectionListener; -import org.elasticsearch.transport.TransportService; import java.io.IOException; import java.net.InetAddress; @@ -78,6 +74,7 @@ import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyMap; @@ -787,4 +784,89 @@ public void onFailure(Exception e) { } } } + + public void testConnectedNodesConcurrentAccess() throws IOException, InterruptedException { + List knownNodes = new CopyOnWriteArrayList<>(); + List discoverableTransports = new CopyOnWriteArrayList<>(); + try { + final int numDiscoverableNodes = randomIntBetween(5, 20); + List discoverableNodes = new ArrayList<>(numDiscoverableNodes); + for (int i = 0; i < numDiscoverableNodes; i++ ) { + MockTransportService transportService = startTransport("discoverable_node" + i, knownNodes, Version.CURRENT); + discoverableNodes.add(transportService.getLocalDiscoNode()); + discoverableTransports.add(transportService); + } + + List seedNodes = randomSubsetOf(discoverableNodes); + Collections.shuffle(seedNodes, random()); + + try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) { + service.start(); + service.acceptIncomingRequests(); + try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster", + seedNodes, service, Integer.MAX_VALUE, n -> true)) { + final int numGetThreads = randomIntBetween(4, 10); + final Thread[] getThreads = new Thread[numGetThreads]; + final int numModifyingThreads = randomIntBetween(4, 10); + final Thread[] modifyingThreads = new Thread[numModifyingThreads]; + CyclicBarrier barrier = new CyclicBarrier(numGetThreads + numModifyingThreads); + for (int i = 0; i < getThreads.length; i++) { + final int numGetCalls = randomIntBetween(1000, 10000); + getThreads[i] = new Thread(() -> { + try { + barrier.await(); + for (int j = 0; j < numGetCalls; j++) { + try { + DiscoveryNode node = connection.getConnectedNode(); + assertNotNull(node); + } catch (IllegalStateException e) { + if (e.getMessage().startsWith("No node available for cluster:") == false) { + throw e; + } + } + } + } catch (Exception ex) { + throw new AssertionError(ex); + } + }); + getThreads[i].start(); + } + + final AtomicInteger counter = new AtomicInteger(); + for (int i = 0; i < modifyingThreads.length; i++) { + final int numDisconnects = randomIntBetween(5, 10); + modifyingThreads[i] = new Thread(() -> { + try { + barrier.await(); + for (int j = 0; j < numDisconnects; j++) { + if (randomBoolean()) { + MockTransportService transportService = + startTransport("discoverable_node_added" + counter.incrementAndGet(), knownNodes, + Version.CURRENT); + discoverableTransports.add(transportService); + connection.addConnectedNode(transportService.getLocalDiscoNode()); + } else { + DiscoveryNode node = randomFrom(discoverableNodes); + connection.onNodeDisconnected(node); + } + } + } catch (Exception ex) { + throw new AssertionError(ex); + } + }); + modifyingThreads[i].start(); + } + + for (Thread thread : getThreads) { + thread.join(); + } + for (Thread thread : modifyingThreads) { + thread.join(); + } + } + } + } finally { + IOUtils.closeWhileHandlingException(discoverableTransports); + } + } }