From 30f722892b2184b2b49c5686ca9032511396e07b Mon Sep 17 00:00:00 2001 From: liketic Date: Sat, 16 Dec 2017 16:19:16 +0800 Subject: [PATCH] Calculate sum in Kahan summation algorithm in aggregations (#27807) --- .../metrics/avg/AvgAggregator.java | 18 ++++++++--- .../metrics/stats/StatsAggregator.java | 18 ++++++++--- .../metrics/sum/SumAggregator.java | 17 ++++++++--- .../metrics/StatsAggregatorTests.java | 24 +++++++++++++++ .../metrics/SumAggregatorTests.java | 21 ++++++++++++- .../metrics/avg/AvgAggregatorTests.java | 30 +++++++++++++++---- 6 files changed, 109 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java index 0decfa05575e4..53c6b505abd77 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java @@ -46,6 +46,8 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue { DoubleArray sums; DocValueFormat format; + private DoubleArray compensations; + public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context, Aggregator parent, List pipelineAggregators, Map metaData) throws IOException { super(name, context, parent, pipelineAggregators, metaData); @@ -55,6 +57,7 @@ public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFor final BigArrays bigArrays = context.bigArrays(); counts = bigArrays.newLongArray(1, true); sums = bigArrays.newDoubleArray(1, true); + compensations = bigArrays.newDoubleArray(1, true); } } @@ -76,15 +79,22 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, public void collect(int doc, long bucket) throws IOException { counts = bigArrays.grow(counts, bucket + 1); sums = bigArrays.grow(sums, bucket + 1); + compensations = bigArrays.grow(compensations, bucket + 1); if (values.advanceExact(doc)) { final int valueCount = values.docValueCount(); counts.increment(bucket, valueCount); - double sum = 0; + double sum = sums.get(bucket); + double compensation = compensations.get(bucket); + for (int i = 0; i < valueCount; i++) { - sum += values.nextValue(); + double corrected = values.nextValue() - compensation; + double newSum = sum + corrected; + compensation = (newSum - sum) - corrected; + sum = newSum; } - sums.increment(bucket, sum); + sums.set(bucket, sum); + compensations.set(bucket, compensation); } } }; @@ -113,7 +123,7 @@ public InternalAggregation buildEmptyAggregation() { @Override public void doClose() { - Releasables.close(counts, sums); + Releasables.close(counts, sums, compensations); } } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/StatsAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/StatsAggregator.java index cca176bd1ad5f..80639ea9ad3d5 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/StatsAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/StatsAggregator.java @@ -48,6 +48,8 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue { DoubleArray mins; DoubleArray maxes; + private DoubleArray compensations; + public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat format, SearchContext context, @@ -59,6 +61,7 @@ public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueF final BigArrays bigArrays = context.bigArrays(); counts = bigArrays.newLongArray(1, true); sums = bigArrays.newDoubleArray(1, true); + compensations = bigArrays.newDoubleArray(1, true); mins = bigArrays.newDoubleArray(1, false); mins.fill(0, mins.size(), Double.POSITIVE_INFINITY); maxes = bigArrays.newDoubleArray(1, false); @@ -88,6 +91,7 @@ public void collect(int doc, long bucket) throws IOException { final long overSize = BigArrays.overSize(bucket + 1); counts = bigArrays.resize(counts, overSize); sums = bigArrays.resize(sums, overSize); + compensations = bigArrays.resize(compensations, overSize); mins = bigArrays.resize(mins, overSize); maxes = bigArrays.resize(maxes, overSize); mins.fill(from, overSize, Double.POSITIVE_INFINITY); @@ -97,16 +101,22 @@ public void collect(int doc, long bucket) throws IOException { if (values.advanceExact(doc)) { final int valuesCount = values.docValueCount(); counts.increment(bucket, valuesCount); - double sum = 0; double min = mins.get(bucket); double max = maxes.get(bucket); + double sum = sums.get(bucket); + double compensation = compensations.get(bucket); + for (int i = 0; i < valuesCount; i++) { double value = values.nextValue(); - sum += value; + double corrected = value - compensation; + double newSum = sum + corrected; + compensation = (newSum - sum) - corrected; + sum = newSum; min = Math.min(min, value); max = Math.max(max, value); } - sums.increment(bucket, sum); + sums.set(bucket, sum); + compensations.set(bucket, compensation); mins.set(bucket, min); maxes.set(bucket, max); } @@ -164,6 +174,6 @@ public InternalAggregation buildEmptyAggregation() { @Override public void doClose() { - Releasables.close(counts, maxes, mins, sums); + Releasables.close(counts, maxes, mins, sums, compensations); } } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/SumAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/SumAggregator.java index bd325b39373e5..cd4f381491178 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/SumAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/sum/SumAggregator.java @@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue { private final DocValueFormat format; private DoubleArray sums; + private DoubleArray compensations; SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context, Aggregator parent, List pipelineAggregators, Map metaData) throws IOException { @@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue { this.format = formatter; if (valuesSource != null) { sums = context.bigArrays().newDoubleArray(1, true); + compensations = context.bigArrays().newDoubleArray(1, true); } } @@ -71,13 +73,20 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, @Override public void collect(int doc, long bucket) throws IOException { sums = bigArrays.grow(sums, bucket + 1); + compensations = bigArrays.grow(compensations, bucket + 1); + if (values.advanceExact(doc)) { final int valuesCount = values.docValueCount(); - double sum = 0; + double sum = sums.get(bucket); + double compensation = compensations.get(bucket); for (int i = 0; i < valuesCount; i++) { - sum += values.nextValue(); + double corrected = values.nextValue() - compensation; + double newSum = sum + corrected; + compensation = (newSum - sum) - corrected; + sum = newSum; } - sums.increment(bucket, sum); + compensations.set(bucket, compensation); + sums.set(bucket, sum); } } }; @@ -106,6 +115,6 @@ public InternalAggregation buildEmptyAggregation() { @Override public void doClose() { - Releasables.close(sums); + Releasables.close(sums, compensations); } } diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/StatsAggregatorTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/StatsAggregatorTests.java index 7286c7de0fed5..e50d89caa0f4d 100644 --- a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/StatsAggregatorTests.java +++ b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/StatsAggregatorTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.metrics; import org.apache.lucene.document.Document; +import org.apache.lucene.document.DoubleDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.RandomIndexWriter; @@ -36,6 +37,8 @@ import java.io.IOException; import java.util.function.Consumer; +import static java.util.Collections.singleton; + public class StatsAggregatorTests extends AggregatorTestCase { static final double TOLERANCE = 1e-10; @@ -113,6 +116,27 @@ public void testRandomLongs() throws IOException { ); } + public void testSummationAccuracy() throws IOException { + MappedFieldType ft = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE); + final String fieldName = "field"; + ft.setName(fieldName); + testCase(ft, + iw -> { + double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7}; + for (double value : values) { + iw.addDocument(singleton(new DoubleDocValuesField(fieldName, value))); + } + }, + stats -> { + assertEquals(15, stats.getCount()); + assertEquals(0.9, stats.getAvg(), 0d); + assertEquals(13.5, stats.getSum(), 0d); + assertEquals(1.7, stats.getMax(), 0d); + assertEquals(0.1, stats.getMin(), 0d); + } + ); + } + public void testCase(MappedFieldType ft, CheckedConsumer buildIndex, Consumer verify) throws IOException { diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/SumAggregatorTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/SumAggregatorTests.java index ff9888a4981d3..c0c3b090b9be3 100644 --- a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/SumAggregatorTests.java +++ b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/SumAggregatorTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.search.aggregations.metrics; +import org.apache.lucene.document.DoubleDocValuesField; import org.apache.lucene.document.Field; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedDocValuesField; @@ -116,10 +117,28 @@ public void testStringField() throws IOException { "Re-index with correct docvalues type.", e.getMessage()); } + public void testSummationAccuracy() throws IOException { + testCase(new MatchAllDocsQuery(), + iw -> { + double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7}; + for (double value : values) { + iw.addDocument(singleton(new DoubleDocValuesField(FIELD_NAME, value))); + } + }, + count -> assertEquals(15.3, count.getValue(), 0d), + NumberFieldMapper.NumberType.DOUBLE); + } + private void testCase(Query query, CheckedConsumer indexer, Consumer verify) throws IOException { + testCase(query, indexer, verify, NumberFieldMapper.NumberType.LONG); + } + private void testCase(Query query, + CheckedConsumer indexer, + Consumer verify, + NumberFieldMapper.NumberType fieldNumberType) throws IOException { try (Directory directory = newDirectory()) { try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { indexer.accept(indexWriter); @@ -128,7 +147,7 @@ private void testCase(Query query, try (IndexReader indexReader = DirectoryReader.open(directory)) { IndexSearcher indexSearcher = newSearcher(indexReader, true, true); - MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG); + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType); fieldType.setName(FIELD_NAME); fieldType.setHasDocValues(true); diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregatorTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregatorTests.java index 2849ede447b60..bfc1c04200665 100644 --- a/core/src/test/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregatorTests.java +++ b/core/src/test/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregatorTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.metrics.avg; +import org.apache.lucene.document.DoubleDocValuesField; import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; @@ -34,9 +35,6 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.search.aggregations.AggregatorTestCase; -import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregationBuilder; -import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregator; -import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg; import java.io.IOException; import java.util.Arrays; @@ -103,8 +101,28 @@ public void testQueryFiltersAll() throws IOException { }); } - private void testCase(Query query, CheckedConsumer buildIndex, Consumer verify) - throws IOException { + public void testSummationAccuracy() throws IOException { + testCase(new MatchAllDocsQuery(), + iw -> { + double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7}; + for (double value : values) { + iw.addDocument(singleton(new DoubleDocValuesField("number", value))); + } + }, + avg -> assertEquals(0.9, avg.getValue(), 0d), + NumberFieldMapper.NumberType.DOUBLE); + } + + private void testCase(Query query, + CheckedConsumer buildIndex, + Consumer verify) throws IOException { + testCase(query, buildIndex, verify, NumberFieldMapper.NumberType.LONG); + } + + private void testCase(Query query, + CheckedConsumer buildIndex, + Consumer verify, + NumberFieldMapper.NumberType fieldNumberType) throws IOException { Directory directory = newDirectory(); RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); buildIndex.accept(indexWriter); @@ -114,7 +132,7 @@ private void testCase(Query query, CheckedConsumer