TSDB: Add time series aggs cancellation #83492

Merged: 11 commits, Feb 15, 2022
File: SearchCancellationIT.java
@@ -11,6 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
@@ -28,6 +29,8 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.rest.RestStatus;
@@ -36,13 +39,16 @@
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.timeseries.TimeSeriesAggregationBuilder;
import org.elasticsearch.search.lookup.LeafStoredFieldsLookup;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.transport.TransportService;
import org.junit.BeforeClass;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
@@ -55,9 +61,12 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static org.elasticsearch.index.IndexSettings.TIME_SERIES_END_TIME;
import static org.elasticsearch.index.IndexSettings.TIME_SERIES_START_TIME;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.index.query.QueryBuilders.scriptQuery;
import static org.elasticsearch.search.SearchCancellationIT.ScriptedBlockPlugin.SEARCH_BLOCK_SCRIPT_NAME;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.hamcrest.Matchers.containsString;
@@ -69,14 +78,20 @@
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE)
public class SearchCancellationIT extends ESIntegTestCase {

private static boolean lowLevelCancellation;

@BeforeClass
public static void init() {
lowLevelCancellation = randomBoolean();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singleton(ScriptedBlockPlugin.class);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
boolean lowLevelCancellation = randomBoolean();
logger.info("Using lowLevelCancellation: {}", lowLevelCancellation);
return Settings.builder()
.put(super.nodeSettings(nodeOrdinal, otherSettings))
@@ -227,7 +242,12 @@ public void testCancellationDuringAggregation() throws Exception {
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.COMBINE_SCRIPT_NAME, Collections.emptyMap())
)
.reduceScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.REDUCE_SCRIPT_NAME, Collections.emptyMap())
new Script(
ScriptType.INLINE,
"mockscript",
ScriptedBlockPlugin.REDUCE_BLOCK_SCRIPT_NAME,
Collections.emptyMap()
)
)
)
)
@@ -238,6 +258,89 @@
ensureSearchWasCancelled(searchResponse);
}

public void testCancellationDuringTimeSeriesAggregation() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
boolean blockInReduce = false;
int numberOfShards = between(2, 5);
long now = Instant.now().toEpochMilli();
assertAcked(
prepareCreate("test").setSettings(
Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.put(IndexSettings.MODE.getKey(), IndexMode.TIME_SERIES.name())
.put(IndexMetadata.INDEX_ROUTING_PATH.getKey(), "dim")
.put(TIME_SERIES_START_TIME.getKey(), now)
.put(TIME_SERIES_END_TIME.getKey(), now + 101)
.build()
).setMapping("""
{
"properties": {
"@timestamp": {"type": "date", "format": "epoch_millis"},
"dim": {"type": "keyword", "time_series_dimension": true}
}
}
""")
);

for (int i = 0; i < 5; i++) {
// Make sure we have a few segments
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int j = 0; j < 20; j++) {
bulkRequestBuilder.add(
client().prepareIndex("test")
.setOpType(DocWriteRequest.OpType.CREATE)
.setSource("@timestamp", now + i * 20 + j, "val", (double) j, "dim", String.valueOf(i))
);
}
assertNoFailures(bulkRequestBuilder.get());
Review comment (Member): This won't really "make sure" we have a few segments - it just makes it quite likely. I'm guessing that if you were to run an index stats call after this to confirm we had a few segments, it'd fail some percent of the time. It's fine, I think, but maybe the comment is optimistic.
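A hypothetical way to check that assumption directly (not part of this PR; a sketch using the indices segments API, with IndicesSegmentResponse, IndexShardSegments, ShardSegments, and Hamcrest's greaterThan assumed to be imported):

    // Sketch: assert every shard copy of "test" ended up with more than one segment.
    IndicesSegmentResponse segments = client().admin().indices().prepareSegments("test").get();
    for (IndexShardSegments shardSegments : segments.getIndices().get("test")) {
        for (ShardSegments shardCopy : shardSegments) {
            // Per the review, this assertion would be expected to fail some percent of the time.
            assertThat(shardCopy.getSegments().size(), greaterThan(1));
        }
    }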

}

logger.info("Executing search");
TimeSeriesAggregationBuilder timeSeriesAggregationBuilder = new TimeSeriesAggregationBuilder("test_agg");
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setQuery(matchAllQuery())
.addAggregation(
timeSeriesAggregationBuilder.subAggregation(
new ScriptedMetricAggregationBuilder("sub_agg").initScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.INIT_SCRIPT_NAME, Collections.emptyMap())
)
.mapScript(
new Script(
ScriptType.INLINE,
"mockscript",
blockInReduce ? ScriptedBlockPlugin.MAP_SCRIPT_NAME : ScriptedBlockPlugin.MAP_BLOCK_SCRIPT_NAME,
Collections.emptyMap()
)
)
.combineScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.COMBINE_SCRIPT_NAME, Collections.emptyMap())
)
.reduceScript(
new Script(
ScriptType.INLINE,
"mockscript",
blockInReduce ? ScriptedBlockPlugin.REDUCE_BLOCK_SCRIPT_NAME : ScriptedBlockPlugin.REDUCE_FAIL_SCRIPT_NAME,
Collections.emptyMap()
)
)
)
)
.execute();
awaitForBlock(plugins);
cancelSearch(SearchAction.NAME);
disableBlocks(plugins);

SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, searchResponse::actionGet);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST));
logger.info("All shards failed with", ex);
if (lowLevelCancellation) {
// Ensure that we cancelled in LeafWalker and not in reduce phase
assertThat(ExceptionsHelper.stackTrace(ex), containsString("LeafWalker"));
}

}

public void testCancellationOfScrollSearches() throws Exception {

List<ScriptedBlockPlugin> plugins = initBlockFactory();
@@ -414,8 +517,11 @@ public static class ScriptedBlockPlugin extends MockScriptPlugin {
static final String SEARCH_BLOCK_SCRIPT_NAME = "search_block";
static final String INIT_SCRIPT_NAME = "init";
static final String MAP_SCRIPT_NAME = "map";
static final String MAP_BLOCK_SCRIPT_NAME = "map_block";
static final String COMBINE_SCRIPT_NAME = "combine";
static final String REDUCE_SCRIPT_NAME = "reduce";
static final String REDUCE_FAIL_SCRIPT_NAME = "reduce_fail";
static final String REDUCE_BLOCK_SCRIPT_NAME = "reduce_block";
static final String TERM_SCRIPT_NAME = "term";

private final AtomicInteger hits = new AtomicInteger();
@@ -449,10 +555,16 @@ public Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
this::nullScript,
MAP_SCRIPT_NAME,
this::nullScript,
MAP_BLOCK_SCRIPT_NAME,
this::mapBlockScript,
COMBINE_SCRIPT_NAME,
this::nullScript,
REDUCE_SCRIPT_NAME,
REDUCE_BLOCK_SCRIPT_NAME,
this::blockScript,
REDUCE_SCRIPT_NAME,
this::termScript,
REDUCE_FAIL_SCRIPT_NAME,
this::reduceFailScript,
TERM_SCRIPT_NAME,
this::termScript
);
@@ -474,6 +586,11 @@ private Object searchBlockScript(Map<String, Object> params) {
return true;
}

private Object reduceFailScript(Map<String, Object> params) {
fail("Shouldn't reach reduce");
return true;
}

private Object nullScript(Map<String, Object> params) {
return null;
}
@@ -483,7 +600,9 @@ private Object blockScript(Map<String, Object> params) {
if (runnable != null) {
runnable.run();
}
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in reduce");
if (shouldBlock.get()) {
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in reduce");
Review comment (Member): Let's just declare a private static final Logger logger = LogManager.getLogger()?
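A minimal sketch of that suggestion (assuming org.apache.logging.log4j.Logger is imported in the test class):

    // One shared logger for the plugin, replacing the per-call LogManager lookup.
    private static final Logger logger = LogManager.getLogger(SearchCancellationIT.class);

    // ... then inside blockScript():
    if (shouldBlock.get()) {
        logger.info("Blocking in reduce");
    }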

}
hits.incrementAndGet();
try {
assertBusy(() -> assertFalse(shouldBlock.get()));
@@ -493,6 +612,23 @@
return 42;
}

private Object mapBlockScript(Map<String, Object> params) {
final Runnable runnable = beforeExecution.get();
if (runnable != null) {
runnable.run();
}
if (shouldBlock.get()) {
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in map");
}
hits.incrementAndGet();
try {
assertBusy(() -> assertFalse(shouldBlock.get()));
} catch (Exception e) {
throw new RuntimeException(e);
}
return 1;
}

private Object termScript(Map<String, Object> params) {
return 1;
}
File: AggregationPhase.java
@@ -8,11 +8,14 @@
package org.elasticsearch.search.aggregations;

import org.apache.lucene.search.Collector;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.aggregations.timeseries.TimeSeriesIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.InternalProfileCollector;
import org.elasticsearch.search.query.QueryPhase;

import java.io.IOException;
import java.util.ArrayList;
@@ -40,7 +43,7 @@ public void preProcess(SearchContext context) {
}
if (context.aggregations().factories().context() != null
&& context.aggregations().factories().context().isInSortOrderExecutionRequired()) {
TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher());
TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), getCancellationChecks(context));
try {
searcher.search(context.rewrittenQuery(), bucketCollector);
} catch (IOException e) {
@@ -55,6 +58,36 @@
}
}

private List<Runnable> getCancellationChecks(SearchContext context) {
List<Runnable> cancellationChecks = new ArrayList<>();
if (context.lowLevelCancellation()) {
// This search doesn't live beyond this phase, so we don't need to remove the query cancellation check
cancellationChecks.add(() -> {
final SearchShardTask task = context.getTask();
if (task != null) {
task.ensureNotCancelled();
}
});
}

boolean timeoutSet = context.scrollContext() == null
&& context.timeout() != null
&& context.timeout().equals(SearchService.NO_TIMEOUT) == false;

if (timeoutSet) {
final long startTime = context.getRelativeTimeInMillis();
final long timeout = context.timeout().millis();
final long maxTime = startTime + timeout;
cancellationChecks.add(() -> {
final long time = context.getRelativeTimeInMillis();
if (time > maxTime) {
throw new QueryPhase.TimeExceededException();
}
});
}
return cancellationChecks;
}

public void execute(SearchContext context) {
if (context.aggregations() == null) {
context.queryResult().aggregations(null);
File: TimeSeriesIndexSearcher.java
@@ -12,7 +12,9 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
@@ -25,8 +27,10 @@
import org.elasticsearch.index.mapper.TimeSeriesIdFieldMapper;
import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.internal.CancellableScorer;

import java.io.IOException;
import java.util.List;

/**
* An IndexSearcher wrapper that executes the searches in time-series indices by traversing them by tsid and timestamp
@@ -37,14 +41,16 @@ public class TimeSeriesIndexSearcher {
// We need to delegate to the other searcher here as opposed to extending IndexSearcher and inheriting default implementations as the
// IndexSearcher would most of the time be a ContextIndexSearcher that has important logic related to e.g. document-level security.
private final IndexSearcher searcher;
private final List<Runnable> cancellations;

public TimeSeriesIndexSearcher(IndexSearcher searcher) {
public TimeSeriesIndexSearcher(IndexSearcher searcher, List<Runnable> cancellations) {
this.searcher = searcher;
this.cancellations = cancellations;
}

public void search(Query query, BucketCollector bucketCollector) throws IOException {
query = searcher.rewrite(query);
Weight weight = searcher.createWeight(query, bucketCollector.scoreMode(), 1);
Weight weight = wrapWeight(searcher.createWeight(query, bucketCollector.scoreMode(), 1));
PriorityQueue<LeafWalker> queue = new PriorityQueue<>(searcher.getIndexReader().leaves().size()) {
@Override
protected boolean lessThan(LeafWalker a, LeafWalker b) {
@@ -121,4 +127,43 @@ boolean next() throws IOException {
return false;
}
}

private void checkCancelled() {
for (Runnable r : cancellations) {
r.run();
}
}

private Weight wrapWeight(Weight weight) {
if (cancellations.isEmpty() == false) {
return new Weight(weight.getQuery()) {
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
throw new UnsupportedOperationException();
Review comment (Contributor): Delegate to weight.explain(context, doc) here?
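A sketch of the suggested delegation in place of the unconditional throw:

    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        // Forward to the wrapped Weight; no cancellation check is needed for explain.
        return weight.explain(context, doc);
    }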

}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
throw new UnsupportedOperationException();
Review comment (Contributor): And delegate here too.
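And the matching sketch here:

    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
        // The cancellation wrapper doesn't change caching semantics, so delegate.
        return weight.isCacheable(ctx);
    }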

}

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
Scorer scorer = weight.scorer(context);
if (scorer != null) {
return new CancellableScorer(scorer, () -> checkCancelled());
} else {
return null;
}
}

@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
return weight.bulkScorer(context);
}
};
} else {
return weight;
}
}
}
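For orientation, a minimal sketch of how a caller is expected to wire up the cancellation hooks, mirroring getCancellationChecks in AggregationPhase above (task, contextSearcher, query, and bucketCollector are assumed to be in scope):

    // Sketch only: build a low-level cancellation hook and run a time-series search.
    List<Runnable> cancellations = List.of(task::ensureNotCancelled);
    TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(contextSearcher, cancellations);
    searcher.search(query, bucketCollector);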