Abstract AsyncShardFetch cache to allow restructuring for batch mode
Signed-off-by: Aman Khare <[email protected]>
Aman Khare committed Feb 23, 2024
1 parent 87ac374 commit c26c6c1
Showing 3 changed files with 425 additions and 228 deletions.
244 changes: 16 additions & 228 deletions server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java
@@ -86,7 +86,7 @@ public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, N
     protected final String type;
     protected final Map<ShardId, ShardAttributes> shardAttributesMap;
     private final Lister<BaseNodesResponse<T>, T> action;
-    private final Map<String, NodeEntry<T>> cache = new HashMap<>();
+    private final BaseShardCache<T> cache;
     private final AtomicLong round = new AtomicLong();
     private boolean closed;
     private final String reroutingKey;
@@ -109,6 +109,7 @@ protected AsyncShardFetch(
         this.action = (Lister<BaseNodesResponse<T>, T>) action;
         this.reroutingKey = "ShardId=[" + shardId.toString() + "]";
         enableBatchMode = false;
+        cache = new ShardCache<>(logger, reroutingKey, type);
     }

     /**
@@ -134,26 +135,14 @@ protected AsyncShardFetch(
         this.action = (Lister<BaseNodesResponse<T>, T>) action;
         this.reroutingKey = "BatchID=[" + batchId + "]";
         enableBatchMode = true;
+        cache = new ShardCache<>(logger, reroutingKey, type);
     }

     @Override
     public synchronized void close() {
         this.closed = true;
     }

-    /**
-     * Returns the number of async fetches that are currently ongoing.
-     */
-    public synchronized int getNumberOfInFlightFetches() {
-        int count = 0;
-        for (NodeEntry<T> nodeEntry : cache.values()) {
-            if (nodeEntry.isFetching()) {
-                count++;
-            }
-        }
-        return count;
-    }
-
     /**
      * Fetches the data for the relevant shard. If there any ongoing async fetches going on, or new ones have
      * been initiated by this call, the result will have no data.
@@ -187,48 +176,26 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId,
             shardToIgnoreNodes.put(ignoreNodesEntry.getKey(), ignoreNodesSet);
         }

-        fillShardCacheWithDataNodes(cache, nodes);
-        List<NodeEntry<T>> nodesToFetch = findNodesToFetch(cache);
-        if (nodesToFetch.isEmpty() == false) {
+        cache.fillShardCacheWithDataNodes(nodes);
+        List<String> nodeIds = cache.findNodesToFetch();
+        if (nodeIds.isEmpty() == false) {
             // mark all node as fetching and go ahead and async fetch them
             // use a unique round id to detect stale responses in processAsyncFetch
             final long fetchingRound = round.incrementAndGet();
-            for (NodeEntry<T> nodeEntry : nodesToFetch) {
-                nodeEntry.markAsFetching(fetchingRound);
-            }
-            DiscoveryNode[] discoNodesToFetch = nodesToFetch.stream()
-                .map(NodeEntry::getNodeId)
+            cache.markAsFetching(nodeIds, fetchingRound);
+            DiscoveryNode[] discoNodesToFetch = nodeIds.stream()
                 .map(nodes::get)
                 .toArray(DiscoveryNode[]::new);
             asyncFetch(discoNodesToFetch, fetchingRound);
         }

         // if we are still fetching, return null to indicate it
-        if (hasAnyNodeFetching(cache)) {
+        if (cache.hasAnyNodeFetching()) {
             return new FetchResult<>(null, emptyMap());
         } else {
             // nothing to fetch, yay, build the return value
-            Map<DiscoveryNode, T> fetchData = new HashMap<>();
             Set<String> failedNodes = new HashSet<>();
-            for (Iterator<Map.Entry<String, NodeEntry<T>>> it = cache.entrySet().iterator(); it.hasNext();) {
-                Map.Entry<String, NodeEntry<T>> entry = it.next();
-                String nodeId = entry.getKey();
-                NodeEntry<T> nodeEntry = entry.getValue();
-
-                DiscoveryNode node = nodes.get(nodeId);
-                if (node != null) {
-                    if (nodeEntry.isFailed()) {
-                        // if its failed, remove it from the list of nodes, so if this run doesn't work
-                        // we try again next round to fetch it again
-                        it.remove();
-                        failedNodes.add(nodeEntry.getNodeId());
-                    } else {
-                        if (nodeEntry.getValue() != null) {
-                            fetchData.put(node, nodeEntry.getValue());
-                        }
-                    }
-                }
-            }
+            Map<DiscoveryNode, T> fetchData = cache.populateCache(nodes, failedNodes);

             Map<ShardId, Set<String>> allIgnoreNodesMap = unmodifiableMap(new HashMap<>(shardToIgnoreNodes));
             // clear the nodes to ignore, we had a successful run in fetching everything we can
@@ -268,77 +235,18 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNod
         logger.trace("{} processing fetched [{}] results", reroutingKey, type);

         if (responses != null) {
-            for (T response : responses) {
-                NodeEntry<T> nodeEntry = cache.get(response.getNode().getId());
-                if (nodeEntry != null) {
-                    if (nodeEntry.getFetchingRound() != fetchingRound) {
-                        assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
-                        logger.trace(
-                            "{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})",
-                            reroutingKey,
-                            nodeEntry.getNodeId(),
-                            type,
-                            nodeEntry.getFetchingRound(),
-                            fetchingRound
-                        );
-                    } else if (nodeEntry.isFailed()) {
-                        logger.trace(
-                            "{} node {} has failed for [{}] (failure [{}])",
-                            reroutingKey,
-                            nodeEntry.getNodeId(),
-                            type,
-                            nodeEntry.getFailure()
-                        );
-                    } else {
-                        // if the entry is there, for the right fetching round and not marked as failed already, process it
-                        logger.trace("{} marking {} as done for [{}], result is [{}]", reroutingKey, nodeEntry.getNodeId(), type, response);
-                        nodeEntry.doneFetching(response);
-                    }
-                }
-            }
+            cache.processResponses(responses, fetchingRound);
         }
         if (failures != null) {
-            for (FailedNodeException failure : failures) {
-                logger.trace("{} processing failure {} for [{}]", reroutingKey, failure, type);
-                NodeEntry<T> nodeEntry = cache.get(failure.nodeId());
-                if (nodeEntry != null) {
-                    if (nodeEntry.getFetchingRound() != fetchingRound) {
-                        assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
-                        logger.trace(
-                            "{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})",
-                            reroutingKey,
-                            nodeEntry.getNodeId(),
-                            type,
-                            nodeEntry.getFetchingRound(),
-                            fetchingRound
-                        );
-                    } else if (nodeEntry.isFailed() == false) {
-                        // if the entry is there, for the right fetching round and not marked as failed already, process it
-                        Throwable unwrappedCause = ExceptionsHelper.unwrapCause(failure.getCause());
-                        // if the request got rejected or timed out, we need to try it again next time...
-                        if (unwrappedCause instanceof OpenSearchRejectedExecutionException
-                            || unwrappedCause instanceof ReceiveTimeoutTransportException
-                            || unwrappedCause instanceof OpenSearchTimeoutException) {
-                            nodeEntry.restartFetching();
-                        } else {
-                            logger.warn(
-                                () -> new ParameterizedMessage(
-                                    "{}: failed to list shard for {} on node [{}]",
-                                    reroutingKey,
-                                    type,
-                                    failure.nodeId()
-                                ),
-                                failure
-                            );
-                            nodeEntry.doneFetching(failure.getCause());
-                        }
-                    }
-                }
-            }
+            cache.processFailures(failures, fetchingRound);
         }
         reroute(reroutingKey, "post_response");
     }

+    public int getNumberOfInFlightFetches() {
+        return cache.getInflightFetches();
+    }
+
     /**
      * Implement this in order to scheduled another round that causes a call to fetch data.
      */
@@ -351,47 +259,6 @@ synchronized void clearCacheForNode(String nodeId) {
         cache.remove(nodeId);
     }

-    /**
-     * Fills the shard fetched data with new (data) nodes and a fresh NodeEntry, and removes from
-     * it nodes that are no longer part of the state.
-     */
-    private void fillShardCacheWithDataNodes(Map<String, NodeEntry<T>> shardCache, DiscoveryNodes nodes) {
-        // verify that all current data nodes are there
-        for (final DiscoveryNode node : nodes.getDataNodes().values()) {
-            if (shardCache.containsKey(node.getId()) == false) {
-                shardCache.put(node.getId(), new NodeEntry<T>(node.getId()));
-            }
-        }
-        // remove nodes that are not longer part of the data nodes set
-        shardCache.keySet().removeIf(nodeId -> !nodes.nodeExists(nodeId));
-    }
-
-    /**
-     * Finds all the nodes that need to be fetched. Those are nodes that have no
-     * data, and are not in fetch mode.
-     */
-    private List<NodeEntry<T>> findNodesToFetch(Map<String, NodeEntry<T>> shardCache) {
-        List<NodeEntry<T>> nodesToFetch = new ArrayList<>();
-        for (NodeEntry<T> nodeEntry : shardCache.values()) {
-            if (nodeEntry.hasData() == false && nodeEntry.isFetching() == false) {
-                nodesToFetch.add(nodeEntry);
-            }
-        }
-        return nodesToFetch;
-    }
-
-    /**
-     * Are there any nodes that are fetching data?
-     */
-    private boolean hasAnyNodeFetching(Map<String, NodeEntry<T>> shardCache) {
-        for (NodeEntry<T> nodeEntry : shardCache.values()) {
-            if (nodeEntry.isFetching()) {
-                return true;
-            }
-        }
-        return false;
-    }
-
     /**
      * Async fetches data for the provided shard with the set of nodes that need to be fetched from.
      */
@@ -460,83 +327,4 @@ public void processAllocation(RoutingAllocation allocation) {

         }
     }
-
-    /**
-     * A node entry, holding the state of the fetched data for a specific shard
-     * for a giving node.
-     */
-    static class NodeEntry<T> {
-        private final String nodeId;
-        private boolean fetching;
-        @Nullable
-        private T value;
-        private boolean valueSet;
-        private Throwable failure;
-        private long fetchingRound;
-
-        NodeEntry(String nodeId) {
-            this.nodeId = nodeId;
-        }
-
-        String getNodeId() {
-            return this.nodeId;
-        }
-
-        boolean isFetching() {
-            return fetching;
-        }
-
-        void markAsFetching(long fetchingRound) {
-            assert fetching == false : "double marking a node as fetching";
-            this.fetching = true;
-            this.fetchingRound = fetchingRound;
-        }
-
-        void doneFetching(T value) {
-            assert fetching : "setting value but not in fetching mode";
-            assert failure == null : "setting value when failure already set";
-            this.valueSet = true;
-            this.value = value;
-            this.fetching = false;
-        }
-
-        void doneFetching(Throwable failure) {
-            assert fetching : "setting value but not in fetching mode";
-            assert valueSet == false : "setting failure when already set value";
-            assert failure != null : "setting failure can't be null";
-            this.failure = failure;
-            this.fetching = false;
-        }
-
-        void restartFetching() {
-            assert fetching : "restarting fetching, but not in fetching mode";
-            assert valueSet == false : "value can't be set when restarting fetching";
-            assert failure == null : "failure can't be set when restarting fetching";
-            this.fetching = false;
-        }
-
-        boolean isFailed() {
-            return failure != null;
-        }
-
-        boolean hasData() {
-            return valueSet || failure != null;
-        }
-
-        Throwable getFailure() {
-            assert hasData() : "getting failure when data has not been fetched";
-            return failure;
-        }
-
-        @Nullable
-        T getValue() {
-            assert failure == null : "trying to fetch value, but its marked as failed, check isFailed";
-            assert valueSet : "value is not set, hasn't been fetched yet";
-            return value;
-        }
-
-        long getFetchingRound() {
-            return fetchingRound;
-        }
-    }
 }
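
Note: the BaseShardCache and ShardCache types referenced by the new code above are introduced in the other two changed files of this commit, whose diffs are not rendered on this page. As a rough orientation only, the sketch below shows the contract implied by the call sites in this diff; the method names are taken from the diff itself, while the parameter and return types are inferred here and the actual classes in the commit may differ.

import java.util.List;
import java.util.Map;
import java.util.Set;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;

// Sketch only: contract inferred from how `cache` is used in AsyncShardFetch above.
// The real BaseShardCache/ShardCache added by this commit may differ in structure and signatures.
abstract class BaseShardCache<K extends BaseNodeResponse> {

    /** Add an entry for every current data node and drop entries for nodes that left the cluster. */
    abstract void fillShardCacheWithDataNodes(DiscoveryNodes nodes);

    /** Node ids that have neither data nor an in-flight fetch, i.e. the nodes to fetch next. */
    abstract List<String> findNodesToFetch();

    /** Mark the given nodes as being fetched in the given round. */
    abstract void markAsFetching(List<String> nodeIds, long fetchingRound);

    /** Whether any node is still fetching. */
    abstract boolean hasAnyNodeFetching();

    /** Number of fetches currently in flight. */
    abstract int getInflightFetches();

    /** Apply successful per-node responses for the given round, ignoring stale rounds. */
    abstract void processResponses(List<K> responses, long fetchingRound);

    /** Apply per-node failures for the given round; retriable failures should be re-queued for fetching. */
    abstract void processFailures(List<FailedNodeException> failures, long fetchingRound);

    /** Build the node-to-data map for nodes with data, adding failed node ids to failedNodes. */
    abstract Map<DiscoveryNode, K> populateCache(DiscoveryNodes nodes, Set<String> failedNodes);

    /** Remove the cache entry for a node. */
    abstract void remove(String nodeId);
}

Pulling these operations behind a single abstraction is what the commit message refers to: a later change can swap the per-shard cache for a batched, multi-shard cache without touching the fetch/response flow in AsyncShardFetch.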
