Skip to content

Commit

Permalink
Calculate sum in Kahan summation algorithm in aggregations (elastic#2…
Browse files Browse the repository at this point in the history
  • Loading branch information
liketic committed Dec 16, 2017
1 parent c93cc1b commit 30f7228
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
DoubleArray sums;
DocValueFormat format;

private DoubleArray compensations;

public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
super(name, context, parent, pipelineAggregators, metaData);
Expand All @@ -55,6 +57,7 @@ public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFor
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
}
}

Expand All @@ -76,15 +79,22 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
public void collect(int doc, long bucket) throws IOException {
    // Make sure every per-bucket array can hold an entry for this bucket.
    counts = bigArrays.grow(counts, bucket + 1);
    sums = bigArrays.grow(sums, bucket + 1);
    compensations = bigArrays.grow(compensations, bucket + 1);

    if (values.advanceExact(doc)) {
        final int valueCount = values.docValueCount();
        counts.increment(bucket, valueCount);

        // Kahan (compensated) summation: resume from the running sum and the
        // compensation term stored for this bucket, so the rounding error that
        // was lost on earlier documents is fed back into later additions.
        double sum = sums.get(bucket);
        double compensation = compensations.get(bucket);

        for (int i = 0; i < valueCount; i++) {
            final double value = values.nextValue();
            if (Double.isFinite(value) == false) {
                // Infinity/NaN must be added plainly: with an infinite value,
                // (newSum - sum) - corrected evaluates to Infinity - Infinity = NaN,
                // which would turn the sum into NaN on the next addition.
                sum += value;
            } else if (Double.isFinite(sum)) {
                // Once the sum is infinite or NaN a finite value cannot change it,
                // so the compensated step only runs while the sum is finite.
                final double corrected = value - compensation;
                final double newSum = sum + corrected;
                compensation = (newSum - sum) - corrected;
                sum = newSum;
            }
        }

        // The locals already include the previous bucket state, so overwrite
        // (set) rather than add (increment).
        sums.set(bucket, sum);
        compensations.set(bucket, compensation);
    }
}
};
Expand Down Expand Up @@ -113,7 +123,7 @@ public InternalAggregation buildEmptyAggregation() {

/**
 * Releases the per-bucket arrays held by this aggregator: the value counts,
 * the running sums and the Kahan compensation terms.
 */
@Override
public void doClose() {
    // Single close call covering every array; closing counts and sums in a
    // separate earlier call would close them twice.
    Releasables.close(counts, sums, compensations);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {
DoubleArray mins;
DoubleArray maxes;

private DoubleArray compensations;


public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat format,
SearchContext context,
Expand All @@ -59,6 +61,7 @@ public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueF
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
Expand Down Expand Up @@ -88,6 +91,7 @@ public void collect(int doc, long bucket) throws IOException {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
Expand All @@ -97,16 +101,22 @@ public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
    final int valuesCount = values.docValueCount();
    counts.increment(bucket, valuesCount);
    double min = mins.get(bucket);
    double max = maxes.get(bucket);

    // Kahan (compensated) summation: resume from the running sum and the
    // compensation term stored for this bucket, so the rounding error that
    // was lost on earlier documents is fed back into later additions.
    double sum = sums.get(bucket);
    double compensation = compensations.get(bucket);

    for (int i = 0; i < valuesCount; i++) {
        final double value = values.nextValue();
        if (Double.isFinite(value) == false) {
            // Infinity/NaN must be added plainly: with an infinite value,
            // (newSum - sum) - corrected evaluates to Infinity - Infinity = NaN,
            // which would turn the sum into NaN on the next addition.
            sum += value;
        } else if (Double.isFinite(sum)) {
            // Each value is added exactly once, through the compensated step.
            final double corrected = value - compensation;
            final double newSum = sum + corrected;
            compensation = (newSum - sum) - corrected;
            sum = newSum;
        }
        min = Math.min(min, value);
        max = Math.max(max, value);
    }

    // The locals already include the previous bucket state, so overwrite
    // (set) rather than add (increment).
    sums.set(bucket, sum);
    compensations.set(bucket, compensation);
    mins.set(bucket, min);
    maxes.set(bucket, max);
}
Expand Down Expand Up @@ -164,6 +174,6 @@ public InternalAggregation buildEmptyAggregation() {

/**
 * Releases the per-bucket arrays held by this aggregator: the value counts,
 * the maxima, the minima, the running sums and the Kahan compensation terms.
 */
@Override
public void doClose() {
    // Single close call covering every array; closing the first four in a
    // separate earlier call would close them twice.
    Releasables.close(counts, maxes, mins, sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
private final DocValueFormat format;

private DoubleArray sums;
private DoubleArray compensations;

SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
Expand All @@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
this.format = formatter;
if (valuesSource != null) {
sums = context.bigArrays().newDoubleArray(1, true);
compensations = context.bigArrays().newDoubleArray(1, true);
}
}

Expand All @@ -71,13 +73,20 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
@Override
public void collect(int doc, long bucket) throws IOException {
    // Make sure both per-bucket arrays can hold an entry for this bucket.
    sums = bigArrays.grow(sums, bucket + 1);
    compensations = bigArrays.grow(compensations, bucket + 1);

    if (values.advanceExact(doc)) {
        final int valuesCount = values.docValueCount();

        // Kahan (compensated) summation: resume from the running sum and the
        // compensation term stored for this bucket, so the rounding error that
        // was lost on earlier documents is fed back into later additions.
        double sum = sums.get(bucket);
        double compensation = compensations.get(bucket);

        for (int i = 0; i < valuesCount; i++) {
            final double value = values.nextValue();
            if (Double.isFinite(value) == false) {
                // Infinity/NaN must be added plainly: with an infinite value,
                // (newSum - sum) - corrected evaluates to Infinity - Infinity = NaN,
                // which would turn the sum into NaN on the next addition.
                sum += value;
            } else if (Double.isFinite(sum)) {
                // Once the sum is infinite or NaN a finite value cannot change it,
                // so the compensated step only runs while the sum is finite.
                final double corrected = value - compensation;
                final double newSum = sum + corrected;
                compensation = (newSum - sum) - corrected;
                sum = newSum;
            }
        }

        // The locals already include the previous bucket state, so overwrite
        // (set) rather than add (increment).
        compensations.set(bucket, compensation);
        sums.set(bucket, sum);
    }
}
};
Expand Down Expand Up @@ -106,6 +115,6 @@ public InternalAggregation buildEmptyAggregation() {

/**
 * Releases the per-bucket arrays held by this aggregator: the running sums
 * and the Kahan compensation terms.
 */
@Override
public void doClose() {
    // Single close call covering both arrays; closing sums in a separate
    // earlier call would close it twice.
    Releasables.close(sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.elasticsearch.search.aggregations.metrics;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.DoubleDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
Expand All @@ -36,6 +37,8 @@
import java.io.IOException;
import java.util.function.Consumer;

import static java.util.Collections.singleton;

public class StatsAggregatorTests extends AggregatorTestCase {
static final double TOLERANCE = 1e-10;

Expand Down Expand Up @@ -113,6 +116,27 @@ public void testRandomLongs() throws IOException {
);
}

/**
 * Indexes fifteen double values, one per document, and checks that the stats
 * aggregation reports count, average, sum, max and min with a tolerance of zero.
 */
public void testSummationAccuracy() throws IOException {
    final String fieldName = "field";
    MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
    fieldType.setName(fieldName);

    testCase(fieldType,
        writer -> {
            final double[] inputs = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
            for (int i = 0; i < inputs.length; i++) {
                writer.addDocument(singleton(new DoubleDocValuesField(fieldName, inputs[i])));
            }
        },
        result -> {
            // A delta of 0d demands the exact expected double, not an approximation.
            assertEquals(15, result.getCount());
            assertEquals(0.9, result.getAvg(), 0d);
            assertEquals(13.5, result.getSum(), 0d);
            assertEquals(1.7, result.getMax(), 0d);
            assertEquals(0.1, result.getMin(), 0d);
        }
    );
}

public void testCase(MappedFieldType ft,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalStats> verify) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.elasticsearch.search.aggregations.metrics;

import org.apache.lucene.document.DoubleDocValuesField;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
Expand Down Expand Up @@ -116,10 +117,28 @@ public void testStringField() throws IOException {
"Re-index with correct docvalues type.", e.getMessage());
}

/**
 * Indexes the seventeen values 0.1, 0.2, ..., 1.7, one per document, against a
 * DOUBLE field and checks that the sum aggregation returns 15.3 with a
 * tolerance of zero.
 */
public void testSummationAccuracy() throws IOException {
    testCase(
        new MatchAllDocsQuery(),
        writer -> {
            final double[] inputs = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
            for (int i = 0; i < inputs.length; i++) {
                writer.addDocument(singleton(new DoubleDocValuesField(FIELD_NAME, inputs[i])));
            }
        },
        sum -> {
            // A delta of 0d demands the exact expected double, not an approximation.
            assertEquals(15.3, sum.getValue(), 0d);
        },
        NumberFieldMapper.NumberType.DOUBLE);
}

/**
 * Convenience overload that runs the four-argument {@code testCase} with the
 * field mapped as {@code NumberType.LONG}.
 */
private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOException> indexer, Consumer<Sum> verify)
        throws IOException {
    final NumberFieldMapper.NumberType defaultType = NumberFieldMapper.NumberType.LONG;
    testCase(query, indexer, verify, defaultType);
}

private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> indexer,
Consumer<Sum> verify,
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
indexer.accept(indexWriter);
Expand All @@ -128,7 +147,7 @@ private void testCase(Query query,
try (IndexReader indexReader = DirectoryReader.open(directory)) {
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);

MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
fieldType.setName(FIELD_NAME);
fieldType.setHasDocValues(true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.search.aggregations.metrics.avg;

import org.apache.lucene.document.DoubleDocValuesField;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
Expand All @@ -34,9 +35,6 @@
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregator;
import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -103,8 +101,28 @@ public void testQueryFiltersAll() throws IOException {
});
}

private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOException> buildIndex, Consumer<InternalAvg> verify)
throws IOException {
/**
 * Indexes fifteen double values, one per document, into the "number" field
 * mapped as DOUBLE and checks that the avg aggregation returns 0.9 with a
 * tolerance of zero.
 */
public void testSummationAccuracy() throws IOException {
    testCase(
        new MatchAllDocsQuery(),
        writer -> {
            final double[] inputs = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
            for (int i = 0; i < inputs.length; i++) {
                writer.addDocument(singleton(new DoubleDocValuesField("number", inputs[i])));
            }
        },
        result -> {
            // A delta of 0d demands the exact expected double, not an approximation.
            assertEquals(0.9, result.getValue(), 0d);
        },
        NumberFieldMapper.NumberType.DOUBLE);
}

/**
 * Convenience overload that runs the four-argument {@code testCase} with the
 * field mapped as {@code NumberType.LONG}.
 */
private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOException> buildIndex, Consumer<InternalAvg> verify)
        throws IOException {
    final NumberFieldMapper.NumberType defaultType = NumberFieldMapper.NumberType.LONG;
    testCase(query, buildIndex, verify, defaultType);
}

private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalAvg> verify,
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
Directory directory = newDirectory();
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
buildIndex.accept(indexWriter);
Expand All @@ -114,7 +132,7 @@ private void testCase(Query query, CheckedConsumer<RandomIndexWriter, IOExceptio
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);

AvgAggregationBuilder aggregationBuilder = new AvgAggregationBuilder("_name").field("number");
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
fieldType.setName("number");

AvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType);
Expand Down

0 comments on commit 30f7228

Please sign in to comment.