From c26c6c116c137ebb81d6ca711bd20b281753ae4c Mon Sep 17 00:00:00 2001 From: Aman Khare Date: Fri, 23 Feb 2024 17:11:20 +0530 Subject: [PATCH] Abstract AsyncShardFetch cache to allow restructuring for batch mode Signed-off-by: Aman Khare --- .../opensearch/gateway/AsyncShardFetch.java | 244 +------------ .../opensearch/gateway/BaseShardCache.java | 329 ++++++++++++++++++ .../org/opensearch/gateway/ShardCache.java | 80 +++++ 3 files changed, 425 insertions(+), 228 deletions(-) create mode 100644 server/src/main/java/org/opensearch/gateway/BaseShardCache.java create mode 100644 server/src/main/java/org/opensearch/gateway/ShardCache.java diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java b/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java index 50774f7e0cb1c..4de8447f07421 100644 --- a/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java @@ -86,7 +86,7 @@ public interface Lister, N protected final String type; protected final Map shardAttributesMap; private final Lister, T> action; - private final Map> cache = new HashMap<>(); + private final BaseShardCache cache; private final AtomicLong round = new AtomicLong(); private boolean closed; private final String reroutingKey; @@ -109,6 +109,7 @@ protected AsyncShardFetch( this.action = (Lister, T>) action; this.reroutingKey = "ShardId=[" + shardId.toString() + "]"; enableBatchMode = false; + cache = new ShardCache<>(logger, reroutingKey, type); } /** @@ -134,6 +135,7 @@ protected AsyncShardFetch( this.action = (Lister, T>) action; this.reroutingKey = "BatchID=[" + batchId + "]"; enableBatchMode = true; + cache = new ShardCache<>(logger, reroutingKey, type); } @Override @@ -141,19 +143,6 @@ public synchronized void close() { this.closed = true; } - /** - * Returns the number of async fetches that are currently ongoing. - */ - public synchronized int getNumberOfInFlightFetches() { - int count = 0; - for (NodeEntry nodeEntry : cache.values()) { - if (nodeEntry.isFetching()) { - count++; - } - } - return count; - } - /** * Fetches the data for the relevant shard. If there any ongoing async fetches going on, or new ones have * been initiated by this call, the result will have no data. @@ -187,48 +176,26 @@ public synchronized FetchResult fetchData(DiscoveryNodes nodes, Map> nodesToFetch = findNodesToFetch(cache); - if (nodesToFetch.isEmpty() == false) { + cache.fillShardCacheWithDataNodes(nodes); + List nodeIds = cache.findNodesToFetch(); + if (nodeIds.isEmpty() == false) { // mark all node as fetching and go ahead and async fetch them // use a unique round id to detect stale responses in processAsyncFetch final long fetchingRound = round.incrementAndGet(); - for (NodeEntry nodeEntry : nodesToFetch) { - nodeEntry.markAsFetching(fetchingRound); - } - DiscoveryNode[] discoNodesToFetch = nodesToFetch.stream() - .map(NodeEntry::getNodeId) + cache.markAsFetching(nodeIds, fetchingRound); + DiscoveryNode[] discoNodesToFetch = nodeIds.stream() .map(nodes::get) .toArray(DiscoveryNode[]::new); asyncFetch(discoNodesToFetch, fetchingRound); } // if we are still fetching, return null to indicate it - if (hasAnyNodeFetching(cache)) { + if (cache.hasAnyNodeFetching()) { return new FetchResult<>(null, emptyMap()); } else { // nothing to fetch, yay, build the return value - Map fetchData = new HashMap<>(); Set failedNodes = new HashSet<>(); - for (Iterator>> it = cache.entrySet().iterator(); it.hasNext();) { - Map.Entry> entry = it.next(); - String nodeId = entry.getKey(); - NodeEntry nodeEntry = entry.getValue(); - - DiscoveryNode node = nodes.get(nodeId); - if (node != null) { - if (nodeEntry.isFailed()) { - // if its failed, remove it from the list of nodes, so if this run doesn't work - // we try again next round to fetch it again - it.remove(); - failedNodes.add(nodeEntry.getNodeId()); - } else { - if (nodeEntry.getValue() != null) { - fetchData.put(node, nodeEntry.getValue()); - } - } - } - } + Map fetchData = cache.populateCache(nodes, failedNodes); Map> allIgnoreNodesMap = unmodifiableMap(new HashMap<>(shardToIgnoreNodes)); // clear the nodes to ignore, we had a successful run in fetching everything we can @@ -268,77 +235,18 @@ protected synchronized void processAsyncFetch(List responses, List nodeEntry = cache.get(response.getNode().getId()); - if (nodeEntry != null) { - if (nodeEntry.getFetchingRound() != fetchingRound) { - assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds"; - logger.trace( - "{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})", - reroutingKey, - nodeEntry.getNodeId(), - type, - nodeEntry.getFetchingRound(), - fetchingRound - ); - } else if (nodeEntry.isFailed()) { - logger.trace( - "{} node {} has failed for [{}] (failure [{}])", - reroutingKey, - nodeEntry.getNodeId(), - type, - nodeEntry.getFailure() - ); - } else { - // if the entry is there, for the right fetching round and not marked as failed already, process it - logger.trace("{} marking {} as done for [{}], result is [{}]", reroutingKey, nodeEntry.getNodeId(), type, response); - nodeEntry.doneFetching(response); - } - } - } + cache.processResponses(responses, fetchingRound); } if (failures != null) { - for (FailedNodeException failure : failures) { - logger.trace("{} processing failure {} for [{}]", reroutingKey, failure, type); - NodeEntry nodeEntry = cache.get(failure.nodeId()); - if (nodeEntry != null) { - if (nodeEntry.getFetchingRound() != fetchingRound) { - assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds"; - logger.trace( - "{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})", - reroutingKey, - nodeEntry.getNodeId(), - type, - nodeEntry.getFetchingRound(), - fetchingRound - ); - } else if (nodeEntry.isFailed() == false) { - // if the entry is there, for the right fetching round and not marked as failed already, process it - Throwable unwrappedCause = ExceptionsHelper.unwrapCause(failure.getCause()); - // if the request got rejected or timed out, we need to try it again next time... - if (unwrappedCause instanceof OpenSearchRejectedExecutionException - || unwrappedCause instanceof ReceiveTimeoutTransportException - || unwrappedCause instanceof OpenSearchTimeoutException) { - nodeEntry.restartFetching(); - } else { - logger.warn( - () -> new ParameterizedMessage( - "{}: failed to list shard for {} on node [{}]", - reroutingKey, - type, - failure.nodeId() - ), - failure - ); - nodeEntry.doneFetching(failure.getCause()); - } - } - } - } + cache.processFailures(failures, fetchingRound); } reroute(reroutingKey, "post_response"); } + public int getNumberOfInFlightFetches() { + return cache.getInflightFetches(); + } + /** * Implement this in order to scheduled another round that causes a call to fetch data. */ @@ -351,47 +259,6 @@ synchronized void clearCacheForNode(String nodeId) { cache.remove(nodeId); } - /** - * Fills the shard fetched data with new (data) nodes and a fresh NodeEntry, and removes from - * it nodes that are no longer part of the state. - */ - private void fillShardCacheWithDataNodes(Map> shardCache, DiscoveryNodes nodes) { - // verify that all current data nodes are there - for (final DiscoveryNode node : nodes.getDataNodes().values()) { - if (shardCache.containsKey(node.getId()) == false) { - shardCache.put(node.getId(), new NodeEntry(node.getId())); - } - } - // remove nodes that are not longer part of the data nodes set - shardCache.keySet().removeIf(nodeId -> !nodes.nodeExists(nodeId)); - } - - /** - * Finds all the nodes that need to be fetched. Those are nodes that have no - * data, and are not in fetch mode. - */ - private List> findNodesToFetch(Map> shardCache) { - List> nodesToFetch = new ArrayList<>(); - for (NodeEntry nodeEntry : shardCache.values()) { - if (nodeEntry.hasData() == false && nodeEntry.isFetching() == false) { - nodesToFetch.add(nodeEntry); - } - } - return nodesToFetch; - } - - /** - * Are there any nodes that are fetching data? - */ - private boolean hasAnyNodeFetching(Map> shardCache) { - for (NodeEntry nodeEntry : shardCache.values()) { - if (nodeEntry.isFetching()) { - return true; - } - } - return false; - } - /** * Async fetches data for the provided shard with the set of nodes that need to be fetched from. */ @@ -460,83 +327,4 @@ public void processAllocation(RoutingAllocation allocation) { } } - - /** - * A node entry, holding the state of the fetched data for a specific shard - * for a giving node. - */ - static class NodeEntry { - private final String nodeId; - private boolean fetching; - @Nullable - private T value; - private boolean valueSet; - private Throwable failure; - private long fetchingRound; - - NodeEntry(String nodeId) { - this.nodeId = nodeId; - } - - String getNodeId() { - return this.nodeId; - } - - boolean isFetching() { - return fetching; - } - - void markAsFetching(long fetchingRound) { - assert fetching == false : "double marking a node as fetching"; - this.fetching = true; - this.fetchingRound = fetchingRound; - } - - void doneFetching(T value) { - assert fetching : "setting value but not in fetching mode"; - assert failure == null : "setting value when failure already set"; - this.valueSet = true; - this.value = value; - this.fetching = false; - } - - void doneFetching(Throwable failure) { - assert fetching : "setting value but not in fetching mode"; - assert valueSet == false : "setting failure when already set value"; - assert failure != null : "setting failure can't be null"; - this.failure = failure; - this.fetching = false; - } - - void restartFetching() { - assert fetching : "restarting fetching, but not in fetching mode"; - assert valueSet == false : "value can't be set when restarting fetching"; - assert failure == null : "failure can't be set when restarting fetching"; - this.fetching = false; - } - - boolean isFailed() { - return failure != null; - } - - boolean hasData() { - return valueSet || failure != null; - } - - Throwable getFailure() { - assert hasData() : "getting failure when data has not been fetched"; - return failure; - } - - @Nullable - T getValue() { - assert failure == null : "trying to fetch value, but its marked as failed, check isFailed"; - assert valueSet : "value is not set, hasn't been fetched yet"; - return value; - } - - long getFetchingRound() { - return fetchingRound; - } - } } diff --git a/server/src/main/java/org/opensearch/gateway/BaseShardCache.java b/server/src/main/java/org/opensearch/gateway/BaseShardCache.java new file mode 100644 index 0000000000000..6a54ee3b976a9 --- /dev/null +++ b/server/src/main/java/org/opensearch/gateway/BaseShardCache.java @@ -0,0 +1,329 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.gateway; + +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.transport.ReceiveTimeoutTransportException; +import reactor.util.annotation.NonNull; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Common functionalities of a cache for storing shard metadata. Cache maintains node level responses. + * Setting up the cache is required from implementation class. While set up, we need 3 functionalities from the user. + * initData : how to initialize an entry of shard cache for a node. + * putData : how to store the response of transport action in the cache. + * getData : how to populate the stored data for any shard allocators like {@link PrimaryShardAllocator} or + * {@link ReplicaShardAllocator} + * + * @param Response type of transport action which has the data to be stored in the cache. + */ +public abstract class BaseShardCache { + private final Logger logger; + private final String logKey; + private final String type; + + protected BaseShardCache(Logger logger, String logKey, String type) { + this.logger = logger; + this.logKey = logKey; + this.type = type; + } + + /** + * Initialize cache's entry for a node. + * + * @param node for which node we need to initialize the cache. + */ + public abstract void initData(DiscoveryNode node); + + /** + * Store the response in the cache from node. + * + * @param node node from which we got the response. + * @param response shard metadata coming from node. + */ + public abstract void putData(DiscoveryNode node, K response); + + /** + * Populate the response from cache. + * + * @param node node for which we need the response. + * @return actual response. + */ + public abstract K getData(DiscoveryNode node); + + @NonNull + public abstract Map getCache(); + + public abstract void clearShardCache(ShardId shardId); + + /** + * Returns the number of fetches that are currently ongoing. + */ + public int getInflightFetches() { + int count = 0; + for (BaseNodeEntry nodeEntry : getCache().values()) { + if (nodeEntry.isFetching()) { + count++; + } + } + return count; + } + + /** + * Fills the shard fetched data with new (data) nodes and a fresh NodeEntry, and removes from + * it nodes that are no longer part of the state. + */ + public void fillShardCacheWithDataNodes(DiscoveryNodes nodes) { + // verify that all current data nodes are there + for (final DiscoveryNode node : nodes.getDataNodes().values()) { + if (getCache().containsKey(node.getId()) == false) { + initData(node); + } + } + // remove nodes that are not longer part of the data nodes set + getCache().keySet().removeIf(nodeId -> !nodes.nodeExists(nodeId)); + } + + /** + * Finds all the nodes that need to be fetched. Those are nodes that have no + * data, and are not in fetch mode. + */ + public List findNodesToFetch() { + List nodesToFetch = new ArrayList<>(); + for (BaseNodeEntry nodeEntry : getCache().values()) { + if (nodeEntry.hasData() == false && nodeEntry.isFetching() == false) { + nodesToFetch.add(nodeEntry.getNodeId()); + } + } + return nodesToFetch; + } + + /** + * Are there any nodes that are fetching data? + */ + public boolean hasAnyNodeFetching() { + for (BaseNodeEntry nodeEntry : getCache().values()) { + if (nodeEntry.isFetching()) { + return true; + } + } + return false; + } + + /** + * Get the data from cache, ignore the failed entries. Use getData functional interface to get the data, as + * different implementations may have different ways to populate the data from cache. + * + * @param nodes Discovery nodes for which we need to return the cache data. + * @param failedNodes return failedNodes with the nodes where fetch has failed. + * @return Map of cache data for every DiscoveryNode. + */ + public Map populateCache(DiscoveryNodes nodes, Set failedNodes) { + Map fetchData = new HashMap<>(); + for (Iterator> it = getCache().entrySet().iterator(); it.hasNext(); ) { + Map.Entry entry = (Map.Entry) it.next(); + String nodeId = entry.getKey(); + BaseNodeEntry nodeEntry = entry.getValue(); + + DiscoveryNode node = nodes.get(nodeId); + if (node != null) { + if (nodeEntry.isFailed()) { + // if its failed, remove it from the list of nodes, so if this run doesn't work + // we try again next round to fetch it again + it.remove(); + failedNodes.add(nodeEntry.getNodeId()); + } else { + K nodeResponse = getData(node); + if (nodeResponse != null) { + fetchData.put(node, nodeResponse); + } + } + } + } + return fetchData; + } + + public void processResponses(List responses, long fetchingRound) { + for (K response : responses) { + BaseNodeEntry nodeEntry = getCache().get(response.getNode().getId()); + if (nodeEntry != null) { + if (validateNodeResponse(nodeEntry, fetchingRound)) { + // if the entry is there, for the right fetching round and not marked as failed already, process it + logger.trace("{} marking {} as done for [{}], result is [{}]", logKey, nodeEntry.getNodeId(), type, + response); + putData(response.getNode(), response); + } + } + } + } + + public boolean validateNodeResponse(BaseNodeEntry nodeEntry, long fetchingRound) { + if (nodeEntry.getFetchingRound() != fetchingRound) { + assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds"; + logger.trace( + "{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})", + logKey, + nodeEntry.getNodeId(), + type, + nodeEntry.getFetchingRound(), + fetchingRound + ); + return false; + } else if (nodeEntry.isFailed()) { + logger.trace( + "{} node {} has failed for [{}] (failure [{}])", + logKey, + nodeEntry.getNodeId(), + type, + nodeEntry.getFailure() + ); + return false; + } + return true; + } + + public void handleNodeFailure(BaseNodeEntry nodeEntry, FailedNodeException failure, long fetchingRound) { + if (nodeEntry.getFetchingRound() != fetchingRound) { + assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds"; + logger.trace( + "{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})", + logKey, + nodeEntry.getNodeId(), + type, + nodeEntry.getFetchingRound(), + fetchingRound + ); + } else if (nodeEntry.isFailed() == false) { + // if the entry is there, for the right fetching round and not marked as failed already, process it + Throwable unwrappedCause = ExceptionsHelper.unwrapCause(failure.getCause()); + // if the request got rejected or timed out, we need to try it again next time... + if (unwrappedCause instanceof OpenSearchRejectedExecutionException + || unwrappedCause instanceof ReceiveTimeoutTransportException + || unwrappedCause instanceof OpenSearchTimeoutException) { + nodeEntry.restartFetching(); + } else { + logger.warn( + () -> new ParameterizedMessage( + "{}: failed to list shard for {} on node [{}]", + logKey, + type, + failure.nodeId() + ), + failure + ); + nodeEntry.doneFetching(failure.getCause()); + } + } + } + + public void processFailures(List failures, long fetchingRound) { + for (FailedNodeException failure : failures) { + logger.trace("{} processing failure {} for [{}]", logKey, failure, type); + BaseNodeEntry nodeEntry = getCache().get(failure.nodeId()); + if (nodeEntry != null) { + handleNodeFailure(nodeEntry, failure, fetchingRound); + } + } + } + + public void remove(String nodeId) { + this.getCache().remove(nodeId); + } + + public void markAsFetching(List nodeIds, long fetchingRound) { + for (String nodeId : nodeIds) { + getCache().get(nodeId).markAsFetching(fetchingRound); + } + } + + + /** + * A node entry, holding only node level fetching related information. + * Actual metadata of shard is stored in child classes. + */ + static class BaseNodeEntry { + private final String nodeId; + private boolean fetching; + private boolean valueSet; + private Throwable failure; + private long fetchingRound; + + BaseNodeEntry(String nodeId) { + this.nodeId = nodeId; + } + + String getNodeId() { + return this.nodeId; + } + + boolean isFetching() { + return fetching; + } + + void markAsFetching(long fetchingRound) { + assert fetching == false : "double marking a node as fetching"; + this.fetching = true; + this.fetchingRound = fetchingRound; + } + + void doneFetching() { + assert fetching : "setting value but not in fetching mode"; + assert failure == null : "setting value when failure already set"; + this.valueSet = true; + this.fetching = false; + } + + void doneFetching(Throwable failure) { + assert fetching : "setting value but not in fetching mode"; + assert valueSet == false : "setting failure when already set value"; + assert failure != null : "setting failure can't be null"; + this.failure = failure; + this.fetching = false; + } + + void restartFetching() { + assert fetching : "restarting fetching, but not in fetching mode"; + assert valueSet == false : "value can't be set when restarting fetching"; + assert failure == null : "failure can't be set when restarting fetching"; + this.fetching = false; + } + + boolean isFailed() { + return failure != null; + } + + boolean hasData() { + return valueSet || failure != null; + } + + Throwable getFailure() { + assert hasData() : "getting failure when data has not been fetched"; + return failure; + } + + long getFetchingRound() { + return fetchingRound; + } + } +} diff --git a/server/src/main/java/org/opensearch/gateway/ShardCache.java b/server/src/main/java/org/opensearch/gateway/ShardCache.java new file mode 100644 index 0000000000000..8526b8cfb84ac --- /dev/null +++ b/server/src/main/java/org/opensearch/gateway/ShardCache.java @@ -0,0 +1,80 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.gateway; + +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.Nullable; +import org.opensearch.core.index.shard.ShardId; + +import java.util.HashMap; +import java.util.Map; + +/** + * Cache implementation of transport actions returning single shard related data in the response. + * + * @param Response type of transport action. + */ +public class ShardCache extends BaseShardCache { + + private final Map> cache = new HashMap<>(); + + public ShardCache(Logger logger, String logKey, String type) { + super(logger, logKey, type); + } + + @Override + public void initData(DiscoveryNode node) { + cache.put(node.getId(), new NodeEntry<>(node.getId())); + } + + @Override + public void putData(DiscoveryNode node, K response) { + cache.get(node.getId()).doneFetching(response); + } + + @Override + public K getData(DiscoveryNode node) { + return cache.get(node.getId()).getValue(); + } + + @Override + public Map getCache() { + return cache; + } + + @Override + public void clearShardCache(ShardId shardId) { + cache.clear(); + } + + /** + * A node entry, holding the state of the fetched data for a specific shard + * for a giving node. + */ + static class NodeEntry extends BaseShardCache.BaseNodeEntry { + @Nullable + private T value; + + void doneFetching(T value) { + super.doneFetching(); + this.value = value; + } + + NodeEntry(String nodeId) { + super(nodeId); + } + + T getValue() { + return value; + } + + } +}