From 03374b3273e7377386782fb60d973140b0c7fb25 Mon Sep 17 00:00:00 2001 From: Colin Goodheart-Smithe Date: Thu, 9 Oct 2014 12:36:02 +0100 Subject: [PATCH] Aggregations: Fixes scripted metrics aggregation when used as a sub aggregation The scripted metric aggregation is now a PER_BUCKET aggregation so that parent buckets are evaluated independently. Also the params and reduceParams are copied for each instance of the aggregator (each parent bucket) so modifications to the values are kept only within the scope of its parent bucket Closes #8036 --- .../metrics/MetricsAggregator.java | 6 +- .../scripted/ScriptedMetricAggregator.java | 46 ++++++++++++-- .../metrics/ScriptedMetricTests.java | 61 ++++++++++++++++++- 3 files changed, 107 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/elasticsearch/search/aggregations/metrics/MetricsAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/metrics/MetricsAggregator.java index 678ec173b2037..27ae902f5fc0a 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/metrics/MetricsAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/metrics/MetricsAggregator.java @@ -26,6 +26,10 @@ public abstract class MetricsAggregator extends Aggregator { protected MetricsAggregator(String name, long estimatedBucketsCount, AggregationContext context, Aggregator parent) { - super(name, BucketAggregationMode.MULTI_BUCKETS, AggregatorFactories.EMPTY, estimatedBucketsCount, context, parent); + this(name, estimatedBucketsCount, BucketAggregationMode.MULTI_BUCKETS, context, parent); + } + + protected MetricsAggregator(String name, long estimatedBucketsCount, BucketAggregationMode bucketAggregationMode, AggregationContext context, Aggregator parent) { + super(name, bucketAggregationMode, AggregatorFactories.EMPTY, estimatedBucketsCount, context, parent); } } diff --git a/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java index 6a21f49d01ba4..8f71ab33f63a6 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java @@ -24,15 +24,17 @@ import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService.ScriptType; import org.elasticsearch.script.SearchScript; +import org.elasticsearch.search.SearchParseException; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.metrics.MetricsAggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; +import java.util.*; +import java.util.Map.Entry; public class ScriptedMetricAggregator extends MetricsAggregator { @@ -51,7 +53,7 @@ public class ScriptedMetricAggregator extends MetricsAggregator { protected ScriptedMetricAggregator(String name, String scriptLang, ScriptType initScriptType, String initScript, ScriptType mapScriptType, String mapScript, ScriptType combineScriptType, String combineScript, ScriptType reduceScriptType, String reduceScript, Map params, Map reduceParams, AggregationContext context, Aggregator parent) { - super(name, 1, context, parent); + super(name, 1, BucketAggregationMode.PER_BUCKET, context, parent); this.scriptService = context.searchContext().scriptService(); this.scriptLang = scriptLang; this.reduceScriptType = reduceScriptType; @@ -59,7 +61,7 @@ protected ScriptedMetricAggregator(String name, String scriptLang, ScriptType in this.params = new HashMap<>(); this.params.put("_agg", new HashMap<>()); } else { - this.params = params; + this.params = new HashMap<>(params); } if (reduceParams == null) { this.reduceParams = new HashMap<>(); @@ -142,9 +144,45 @@ public Factory(String name, String scriptLang, ScriptType initScriptType, String @Override public Aggregator create(AggregationContext context, Aggregator parent, long expectedBucketsCount) { + Map params = null; + if (this.params != null) { + params = deepCopyParams(this.params, context.searchContext()); + } + Map reduceParams = null; + if (this.reduceParams != null) { + reduceParams = deepCopyParams(this.reduceParams, context.searchContext()); + } return new ScriptedMetricAggregator(name, scriptLang, initScriptType, initScript, mapScriptType, mapScript, combineScriptType, combineScript, reduceScriptType, reduceScript, params, reduceParams, context, parent); } + + @SuppressWarnings({ "unchecked" }) + private static T deepCopyParams(T original, SearchContext context) { + T clone; + if (original instanceof Map) { + Map originalMap = (Map) original; + Map clonedMap = new HashMap<>(); + for (Entry e : originalMap.entrySet()) { + clonedMap.put(deepCopyParams(e.getKey(), context), deepCopyParams(e.getValue(), context)); + } + clone = (T) clonedMap; + } else if (original instanceof List) { + List originalList = (List) original; + List clonedList = new ArrayList(); + for (Object o : originalList) { + clonedList.add(deepCopyParams(o, context)); + } + clone = (T) clonedList; + } else if (original instanceof String || original instanceof Integer || original instanceof Long || original instanceof Short + || original instanceof Byte || original instanceof Float || original instanceof Double || original instanceof Character + || original instanceof Boolean) { + clone = original; + } else { + throw new SearchParseException(context, "Can only clone primitives, String, ArrayList, and HashMap. Found: " + + original.getClass().getCanonicalName()); + } + return clone; + } } diff --git a/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricTests.java b/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricTests.java index 6e564cd4bd104..df7b8ba906e3a 100644 --- a/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricTests.java +++ b/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricTests.java @@ -25,7 +25,9 @@ import org.elasticsearch.common.settings.ImmutableSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.aggregations.Aggregation; +import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; +import org.elasticsearch.search.aggregations.bucket.histogram.Histogram.Bucket; import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetric; import org.elasticsearch.test.ElasticsearchIntegrationTest; import org.elasticsearch.test.ElasticsearchIntegrationTest.ClusterScope; @@ -59,7 +61,8 @@ public void setupSuiteScopeCluster() throws Exception { numDocs = randomIntBetween(10, 100); for (int i = 0; i < numDocs; i++) { builders.add(client().prepareIndex("idx", "type", "" + i).setSource( - jsonBuilder().startObject().field("value", randomAsciiOfLengthBetween(5, 15)).endObject())); + jsonBuilder().startObject().field("value", randomAsciiOfLengthBetween(5, 15)) + .field("l_value", i).endObject())); } indexRandom(true, builders); @@ -561,6 +564,62 @@ public void testInitMapCombineReduce_withParams_File() { assertThat(((Number) object).longValue(), equalTo(numDocs * 3)); } + @Test + public void testInitMapCombineReduce_withParams_asSubAgg() { + Map varsMap = new HashMap<>(); + varsMap.put("multiplier", 1); + Map params = new HashMap<>(); + params.put("_agg", new ArrayList<>()); + params.put("vars", varsMap); + + SearchResponse response = client() + .prepareSearch("idx") + .setQuery(matchAllQuery()).setSize(1000) + .addAggregation( + histogram("histo") + .field("l_value") + .interval(1) + .subAggregation( + scriptedMetric("scripted") + .params(params) + .initScript("vars.multiplier = 3") + .mapScript("_agg.add(vars.multiplier)") + .combineScript( + "newaggregation = []; sum = 0;for (a in _agg) { sum += a}; newaggregation.add(sum); return newaggregation") + .reduceScript( + "newaggregation = []; sum = 0;for (aggregation in _aggs) { for (a in aggregation) { sum += a} }; newaggregation.add(sum); return newaggregation"))) + .execute().actionGet(); + assertSearchResponse(response); + assertThat(response.getHits().getTotalHits(), equalTo(numDocs)); + Aggregation aggregation = response.getAggregations().get("histo"); + assertThat(aggregation, notNullValue()); + assertThat(aggregation, instanceOf(Histogram.class)); + Histogram histoAgg = (Histogram) aggregation; + assertThat(histoAgg.getName(), equalTo("histo")); + List buckets = histoAgg.getBuckets(); + assertThat(buckets, notNullValue()); + for (Bucket b : buckets) { + assertThat(b, notNullValue()); + assertThat(b.getDocCount(), equalTo(1l)); + Aggregations subAggs = b.getAggregations(); + assertThat(subAggs, notNullValue()); + assertThat(subAggs.asList().size(), equalTo(1)); + Aggregation subAgg = subAggs.get("scripted"); + assertThat(subAgg, notNullValue()); + assertThat(subAgg, instanceOf(ScriptedMetric.class)); + ScriptedMetric scriptedMetricAggregation = (ScriptedMetric) subAgg; + assertThat(scriptedMetricAggregation.getName(), equalTo("scripted")); + assertThat(scriptedMetricAggregation.aggregation(), notNullValue()); + assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class)); + List aggregationList = (List) scriptedMetricAggregation.aggregation(); + assertThat(aggregationList.size(), equalTo(1)); + Object object = aggregationList.get(0); + assertThat(object, notNullValue()); + assertThat(object, instanceOf(Number.class)); + assertThat(((Number) object).longValue(), equalTo(3l)); + } + } + @Test public void testEmptyAggregation() throws Exception { Map varsMap = new HashMap<>();