Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: Cap shard failure lists to a fixed small size #104147

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions server/src/main/java/org/elasticsearch/ExceptionsHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,43 @@ public static void maybeDieOnAnotherThread(final Throwable throwable) {
}

/**
 * Deduplicates failures, returning at most the first {@code max} deduplicated failures.
 * If {@code aggressive} is false, failures are deduplicated by message, index and cause type.
 * If {@code aggressive} is true, failures are deduplicated by message and cause type only,
 * ignoring the index (see {@code GroupBy#equals}), so the same error on many indices
 * collapses into a single entry.
 *
 * @param failures array to deduplicate
 * @param max limit on size of array returned (beyond that, failures are discarded)
 * @param aggressive how aggressively to deduplicate
 * @return deduplicated array capped at {@code max} entries; if {@code failures} is null or
 *         empty, it is returned without modification
 */
public static ShardOperationFailedException[] groupBy(ShardOperationFailedException[] failures, int max, boolean aggressive) {
    if (failures == null || failures.length == 0) {
        return failures;
    }

    assert max > 0 : "max must be greater than zero";
    // Fast path: non-aggressive grouping never produces more entries than its input, so when
    // the input already fits within 'max' the uncapped variant cannot exceed the cap.
    // NOTE: the previous check used '>=', which delegated to the UNCAPPED groupBy exactly
    // when the input was large enough to need capping, violating the documented limit.
    if (failures.length <= max && aggressive == false) {
        return groupBy(failures);
    }

    // MP TODO: IDEA: this could also add a final Exception (at the n-1 slot) that summarizes all exceptions truncated (not included)
    List<ShardOperationFailedException> uniqueFailures = new ArrayList<>();
    Set<GroupBy> reasons = new HashSet<>();
    for (ShardOperationFailedException failure : failures) {
        // aggressive == false -> the GroupBy key also includes the index
        GroupBy reason = new GroupBy(failure, aggressive == false);
        if (reasons.add(reason)) { // Set#add returns false when the key was already present
            uniqueFailures.add(failure);
        }
        if (uniqueFailures.size() >= max) {
            break; // cap reached; remaining failures are discarded
        }
    }
    return uniqueFailures.toArray(new ShardOperationFailedException[0]);
}

/**
* Deduplicate the failures by exception message, index and cause.
* @param failures array to deduplicate
* @return deduplicated array; if failures is null or empty, it will be returned without modification
*/
Expand All @@ -274,7 +310,7 @@ public static ShardOperationFailedException[] groupBy(ShardOperationFailedExcept
List<ShardOperationFailedException> uniqueFailures = new ArrayList<>();
Set<GroupBy> reasons = new HashSet<>();
for (ShardOperationFailedException failure : failures) {
GroupBy reason = new GroupBy(failure);
GroupBy reason = new GroupBy(failure, true);
if (reasons.contains(reason) == false) {
reasons.add(reason);
uniqueFailures.add(failure);
Expand Down Expand Up @@ -306,8 +342,10 @@ private static class GroupBy {
final String reason;
final String index;
final Class<? extends Throwable> causeType;
private final boolean groupByIndex;

GroupBy(ShardOperationFailedException failure) {
GroupBy(ShardOperationFailedException failure, boolean groupByIndex) {
this.groupByIndex = groupByIndex;
Throwable cause = failure.getCause();
// the index name from the failure contains the cluster alias when using CCS. Ideally failures should be grouped by
// index name and cluster alias. That's why the failure index name has the precedence over the one coming from the cause,
Expand Down Expand Up @@ -335,14 +373,22 @@ public boolean equals(Object o) {
return false;
}
GroupBy groupBy = (GroupBy) o;
return Objects.equals(reason, groupBy.reason)
&& Objects.equals(index, groupBy.index)
&& Objects.equals(causeType, groupBy.causeType);
if (groupByIndex) {
return Objects.equals(reason, groupBy.reason)
&& Objects.equals(index, groupBy.index)
&& Objects.equals(causeType, groupBy.causeType);
} else {
return Objects.equals(reason, groupBy.reason) && Objects.equals(causeType, groupBy.causeType);
}
}

@Override
public int hashCode() {
    // Must stay consistent with equals(): the index participates in the hash only when
    // grouping by index is enabled; otherwise the key is (reason, causeType).
    // (The scraped diff retained the old unconditional 'return Objects.hash(reason, index,
    // causeType);' line above the conditional, which is removed here.)
    return groupByIndex ? Objects.hash(reason, index, causeType) : Objects.hash(reason, causeType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ static TransportVersion def(int id) {
public static final TransportVersion NESTED_KNN_MORE_INNER_HITS = def(8_577_00_0);
public static final TransportVersion REQUIRE_DATA_STREAM_ADDED = def(8_578_00_0);
public static final TransportVersion ML_INFERENCE_COHERE_EMBEDDINGS_ADDED = def(8_579_00_0);
public static final TransportVersion SEARCH_RESPONSE_FAILED_SHARD_COUNT_TRACKING = def(8_580_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
* distributed frequencies
*/
abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> extends SearchPhase implements SearchPhaseContext {
public static final int MAX_FAILURES_IN_RESPONSE = 3;
private static final float DEFAULT_INDEX_BOOST = 1.0f;
private final Logger logger;
private final NamedWriteableRegistry namedWriteableRegistry;
Expand Down Expand Up @@ -394,6 +395,7 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha
* failed or if there was a failure and partial results are not allowed, then we immediately
* fail. Otherwise we continue to the next phase.
*/
// MP TODO: I think this one can use the SearchShardFailures class?
ShardOperationFailedException[] shardSearchFailures = buildShardFailures();
if (shardSearchFailures.length == getNumShards()) {
shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures);
Expand Down Expand Up @@ -460,6 +462,27 @@ private void executePhase(SearchPhase phase) {
}
}

/**
 * Builds a {@code ShardSearchFailures} holder carrying the total number of recorded shard
 * failures together with an aggressively-deduplicated failure list capped at
 * {@code MAX_FAILURES_IN_RESPONSE} entries.
 */
private ShardSearchFailures buildShardSearchFailures() {
    AtomicArray<ShardSearchFailure> recorded = this.shardFailures.get();
    if (recorded == null) {
        // No failures were ever recorded for this search.
        return new ShardSearchFailures(0, ShardSearchFailure.EMPTY_ARRAY);
    }

    List<ShardSearchFailure> all = recorded.asList();
    // Aggressive dedup (ignores index) already caps the result at MAX_FAILURES_IN_RESPONSE.
    ShardOperationFailedException[] deduplicated = ExceptionsHelper.groupBy(
        all.toArray(new ShardSearchFailure[0]),
        MAX_FAILURES_IN_RESPONSE,
        true
    );

    // Narrow each retained element back to ShardSearchFailure (groupBy preserves the
    // original element instances); Math.min guards the cap defensively.
    int retainedCount = Math.min(MAX_FAILURES_IN_RESPONSE, deduplicated.length);
    ShardSearchFailure[] retained = new ShardSearchFailure[retainedCount];
    for (int i = 0; i < retainedCount; i++) {
        retained[i] = (ShardSearchFailure) deduplicated[i];
    }

    // The holder keeps the true total (all.size()) even though the list is truncated.
    return new ShardSearchFailures(all.size(), retained);
}

private ShardSearchFailure[] buildShardFailures() {
AtomicArray<ShardSearchFailure> shardFailures = this.shardFailures.get();
if (shardFailures == null) {
Expand Down Expand Up @@ -659,12 +682,12 @@ public boolean isPartOfPointInTime(ShardSearchContextId contextId) {

private SearchResponse buildSearchResponse(
SearchResponseSections internalSearchResponse,
ShardSearchFailure[] failures,
ShardSearchFailures failures,
String scrollId,
String searchContextId
) {
int numSuccess = successfulOps.get();
int numFailures = failures.length;
int numFailures = failures.getNumFailures();
assert numSuccess + numFailures == getNumShards()
: "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + getNumShards() + ")";
return new SearchResponse(
Expand All @@ -686,11 +709,11 @@ boolean buildPointInTimeFromSearchResults() {

@Override
public void sendSearchResponse(SearchResponseSections internalSearchResponse, AtomicArray<SearchPhaseResult> queryResults) {
ShardSearchFailure[] failures = buildShardFailures();
ShardSearchFailures failures = buildShardSearchFailures();
Boolean allowPartialResults = request.allowPartialSearchResults();
assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults";
if (allowPartialResults == false && failures.length > 0) {
raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures));
if (allowPartialResults == false && failures.getNumFailures() > 0) {
raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures.getFailures()));
} else {
final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults) : null;
final String searchContextId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ private static Throwable deduplicateCause(Throwable cause, ShardSearchFailure[]
return cause;
}

/// MP TODO: IDEA add status to the SearchShardFailures class and change this class to accept that class
/// MP TODO the logic below can go into that class
@Override
public RestStatus status() {
if (shardFailures.length == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
private final int totalShards;
private final int successfulShards;
private final int skippedShards;
private final int failedShards;
private final ShardSearchFailure[] shardFailures;
private final Clusters clusters;
private final long tookInMillis;
Expand Down Expand Up @@ -116,6 +117,11 @@ public SearchResponse(StreamInput in) throws IOException {
tookInMillis = in.readVLong();
skippedShards = in.readVInt();
pointInTimeId = in.readOptionalString();
if (in.getTransportVersion().onOrAfter(TransportVersions.SEARCH_RESPONSE_FAILED_SHARD_COUNT_TRACKING)) {
this.failedShards = in.readVInt();
} else {
this.failedShards = shardFailures.length;
}
}

public SearchResponse(
Expand Down Expand Up @@ -160,7 +166,7 @@ public SearchResponse(
int successfulShards,
int skippedShards,
long tookInMillis,
ShardSearchFailure[] shardFailures,
ShardSearchFailure[] shardFailures, // TODO: change to SearchShardFailures?
Clusters clusters,
String pointInTimeId
) {
Expand All @@ -183,6 +189,38 @@ public SearchResponse(
);
}

/**
 * Creates a {@link SearchResponse} from pre-built response sections and a
 * {@code ShardSearchFailures} holder. The holder is unpacked into the separate
 * failed-shard count ({@code getNumFailures()}) and the retained (possibly
 * truncated) failure array ({@code getFailures()}) expected by the primary
 * constructor, so the reported failed-shard count can exceed the number of
 * failure entries actually carried in the response.
 */
public SearchResponse(
SearchResponseSections searchResponseSections,
String scrollId,
int totalShards,
int successfulShards,
int skippedShards,
long tookInMillis,
ShardSearchFailures shardFailures,
Clusters clusters,
String pointInTimeId
) {
this(
searchResponseSections.hits,
searchResponseSections.aggregations,
searchResponseSections.suggest,
searchResponseSections.timedOut,
searchResponseSections.terminatedEarly,
searchResponseSections.profileResults,
searchResponseSections.numReducePhases,
scrollId,
totalShards,
successfulShards,
skippedShards,
shardFailures.getNumFailures(),
tookInMillis,
shardFailures.getFailures(),
clusters,
pointInTimeId
);
}

public SearchResponse(
SearchHits hits,
Aggregations aggregations,
Expand All @@ -199,6 +237,45 @@ public SearchResponse(
ShardSearchFailure[] shardFailures,
Clusters clusters,
String pointInTimeId
) {
this(
hits,
aggregations,
suggest,
timedOut,
terminatedEarly,
profileResults,
numReducePhases,
scrollId,
totalShards,
successfulShards,
skippedShards,
shardFailures == null ? 0 : shardFailures.length,
tookInMillis,
shardFailures,
clusters,
pointInTimeId
);
}

/// MP TODO: Newly added - use this from AbstractSearchAsyncAction?
public SearchResponse(
SearchHits hits,
Aggregations aggregations,
Suggest suggest,
boolean timedOut,
Boolean terminatedEarly,
SearchProfileResults profileResults,
int numReducePhases,
String scrollId,
int totalShards,
int successfulShards,
int skippedShards,
int failedShards,
long tookInMillis,
ShardSearchFailure[] shardFailures,
Clusters clusters,
String pointInTimeId
) {
this.hits = hits;
hits.incRef();
Expand All @@ -216,6 +293,7 @@ public SearchResponse(
this.skippedShards = skippedShards;
this.tookInMillis = tookInMillis;
this.shardFailures = shardFailures;
this.failedShards = failedShards;
assert skippedShards <= totalShards : "skipped: " + skippedShards + " total: " + totalShards;
assert scrollId == null || pointInTimeId == null
: "SearchResponse can't have both scrollId [" + scrollId + "] and searchContextId [" + pointInTimeId + "]";
Expand Down Expand Up @@ -326,7 +404,7 @@ public int getSkippedShards() {
* The failed number of shards the search was executed on.
*/
public int getFailedShards() {
return shardFailures.length;
return failedShards; // WAS: shardFailures.length;
}

/**
Expand Down Expand Up @@ -568,6 +646,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(tookInMillis);
out.writeVInt(skippedShards);
out.writeOptionalString(pointInTimeId);

if (out.getTransportVersion().onOrAfter(TransportVersions.SEARCH_RESPONSE_FAILED_SHARD_COUNT_TRACKING)) {
out.writeVInt(failedShards);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -124,9 +123,10 @@ public SearchResponse getMergedResponse(Clusters clusters) {
int totalShards = 0;
int skippedShards = 0;
int successfulShards = 0;
int failedShards = 0;
// the current reduce phase counts as one
int numReducePhases = 1;
List<ShardSearchFailure> failures = new ArrayList<>();
List<ShardSearchFailure> failures = new ArrayList<>(AbstractSearchAsyncAction.MAX_FAILURES_IN_RESPONSE);
Map<String, SearchProfileShardResult> profileResults = new HashMap<>();
List<InternalAggregations> aggs = new ArrayList<>();
Map<ShardIdAndClusterAlias, Integer> shards = new TreeMap<>();
Expand All @@ -140,9 +140,17 @@ public SearchResponse getMergedResponse(Clusters clusters) {
totalShards += searchResponse.getTotalShards();
skippedShards += searchResponse.getSkippedShards();
successfulShards += searchResponse.getSuccessfulShards();
failedShards += searchResponse.getFailedShards();
numReducePhases += searchResponse.getNumReducePhases();

Collections.addAll(failures, searchResponse.getShardFailures());
if (failures.size() < AbstractSearchAsyncAction.MAX_FAILURES_IN_RESPONSE) {
for (ShardSearchFailure shardFailure : searchResponse.getShardFailures()) {
failures.add(shardFailure);
if (failures.size() >= AbstractSearchAsyncAction.MAX_FAILURES_IN_RESPONSE) {
break;
}
}
}

profileResults.putAll(searchResponse.getProfileResults());

Expand Down Expand Up @@ -227,6 +235,7 @@ public SearchResponse getMergedResponse(Clusters clusters) {
totalShards,
successfulShards,
skippedShards,
failedShards,
tookInMillis,
shardFailures,
clusters,
Expand Down
Loading