Skip to content

Commit

Permalink
[BUG] opensearch crashes on closed client connection before search re…
Browse files Browse the repository at this point in the history
…ply (opensearch-project#3626) (opensearch-project#3645)

* [BUG] opensearch crashes on closed client connection before search reply

Signed-off-by: Andriy Redko <[email protected]>

* Addressing code review comments

Signed-off-by: Andriy Redko <[email protected]>
(cherry picked from commit 3dba46e)

Co-authored-by: Andriy Redko <[email protected]>
Signed-off-by: Andriy Redko <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and reta committed Jun 22, 2022
1 parent 6deab10 commit 2897a6c
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,11 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
}
final int totalOps = this.totalOps.incrementAndGet();
if (totalOps == expectedTotalOps) {
onPhaseDone();
try {
onPhaseDone();
} catch (final Exception ex) {
onPhaseFailure(this, "The phase has failed", ex);
}
} else if (totalOps > expectedTotalOps) {
throw new AssertionError(
"unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]",
Expand Down Expand Up @@ -559,7 +563,11 @@ private void successfulShardExecution(SearchShardIterator shardsIt) {
}
final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
if (xTotalOps == expectedTotalOps) {
onPhaseDone();
try {
onPhaseDone();
} catch (final Exception ex) {
onPhaseFailure(this, "The phase has failed", ex);
}
} else if (xTotalOps > expectedTotalOps) {
throw new AssertionError(
"unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

package org.opensearch.action.search;

import org.junit.After;
import org.junit.Before;
import org.opensearch.action.ActionListener;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.support.IndicesOptions;
Expand All @@ -43,25 +45,34 @@
import org.opensearch.index.Index;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.shard.ShardId;
import org.opensearch.index.shard.ShardNotFoundException;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.Transport;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
Expand All @@ -71,13 +82,49 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase {

private final List<Tuple<String, String>> resolvedNodes = new ArrayList<>();
private final Set<ShardSearchContextId> releasedContexts = new CopyOnWriteArraySet<>();
private ExecutorService executor;

@Before
@Override
public void setUp() throws Exception {
super.setUp();
executor = Executors.newFixedThreadPool(1);
}

@After
@Override
public void tearDown() throws Exception {
super.tearDown();
executor.shutdown();
assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS));
}

private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
SearchRequest request,
ArraySearchPhaseResults<SearchPhaseResult> results,
ActionListener<SearchResponse> listener,
final boolean controlled,
final AtomicLong expected
) {
return createAction(
request,
results,
listener,
controlled,
false,
expected,
new SearchShardIterator(null, null, Collections.emptyList(), null)
);
}

private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
SearchRequest request,
ArraySearchPhaseResults<SearchPhaseResult> results,
ActionListener<SearchResponse> listener,
final boolean controlled,
final boolean failExecutePhaseOnShard,
final AtomicLong expected,
final SearchShardIterator... shards
) {
final Runnable runnable;
final TransportSearchAction.SearchTimeProvider timeProvider;
Expand Down Expand Up @@ -105,10 +152,10 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
Collections.singletonMap("foo", new AliasFilter(new MatchAllQueryBuilder())),
Collections.singletonMap("foo", 2.0f),
Collections.singletonMap("name", Sets.newHashSet("bar", "baz")),
null,
executor,
request,
listener,
new GroupShardsIterator<>(Collections.singletonList(new SearchShardIterator(null, null, Collections.emptyList(), null))),
new GroupShardsIterator<>(Arrays.asList(shards)),
timeProvider,
ClusterState.EMPTY_STATE,
null,
Expand All @@ -126,7 +173,13 @@ protected void executePhaseOnShard(
final SearchShardIterator shardIt,
final SearchShardTarget shard,
final SearchActionListener<SearchPhaseResult> listener
) {}
) {
if (failExecutePhaseOnShard) {
listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
} else {
listener.onResponse(new QuerySearchResult());
}
}

@Override
long buildTookInMillis() {
Expand Down Expand Up @@ -328,6 +381,102 @@ private static ArraySearchPhaseResults<SearchPhaseResult> phaseResults(
return phaseResults;
}

public void testOnShardFailurePhaseDoneFailure() throws InterruptedException {
final Index index = new Index("test", UUID.randomUUID().toString());
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean fail = new AtomicBoolean(true);

final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10))
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), Arrays.asList("n1", "n2", "n3"), null, null, null))
.toArray(SearchShardIterator[]::new);

SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
searchRequest.setMaxConcurrentShardRequests(1);

final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
searchRequest,
queryResult,
new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse response) {

}

@Override
public void onFailure(Exception e) {
if (fail.compareAndSet(true, false)) {
try {
throw new RuntimeException("Simulated exception");
} finally {
executor.submit(() -> latch.countDown());
}
}
}
},
false,
true,
new AtomicLong(),
shards
);
action.run();
assertTrue(latch.await(1, TimeUnit.SECONDS));

InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
assertThat(searchResponse.getSuccessfulShards(), equalTo(0));
}

public void testOnShardSuccessPhaseDoneFailure() throws InterruptedException {
final Index index = new Index("test", UUID.randomUUID().toString());
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean fail = new AtomicBoolean(true);

final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10))
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), Arrays.asList("n1", "n2", "n3"), null, null, null))
.toArray(SearchShardIterator[]::new);

SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
searchRequest.setMaxConcurrentShardRequests(1);

final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
searchRequest,
queryResult,
new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse response) {
if (fail.compareAndSet(true, false)) {
throw new RuntimeException("Simulated exception");
}
}

@Override
public void onFailure(Exception e) {
executor.submit(() -> latch.countDown());
}
},
false,
false,
new AtomicLong(),
shards
);
action.run();
assertTrue(latch.await(1, TimeUnit.SECONDS));

InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
assertThat(searchResponse.getSuccessfulShards(), equalTo(shards.length));
}

private static final class PhaseResult extends SearchPhaseResult {
PhaseResult(ShardSearchContextId contextId) {
this.contextId = contextId;
Expand Down

0 comments on commit 2897a6c

Please sign in to comment.