Skip to content

Commit

Permalink
Calculate sum in Kahan summation algorithm in aggregations (#27807) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kel authored and jpountz committed Jan 22, 2018
1 parent f668409 commit 82d782b
Show file tree
Hide file tree
Showing 17 changed files with 559 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DocValueFormat format;

public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Expand All @@ -55,6 +56,7 @@ public AvgAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFor
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
}
}

Expand All @@ -76,15 +78,29 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
public void collect(int doc, long bucket) throws IOException {
counts = bigArrays.grow(counts, bucket + 1);
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valueCount = values.docValueCount();
counts.increment(bucket, valueCount);
double sum = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);

for (int i = 0; i < valueCount; i++) {
sum += values.nextValue();
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
}
}
};
Expand Down Expand Up @@ -113,7 +129,7 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, sums);
Releasables.close(counts, sums, compensations);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,20 @@ public String getWriteableName() {
public InternalAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
long count = 0;
double sum = 0;
double compensation = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
for (InternalAggregation aggregation : aggregations) {
count += ((InternalAvg) aggregation).count;
sum += ((InternalAvg) aggregation).sum;
InternalAvg avg = (InternalAvg) aggregation;
count += avg.count;
if (Double.isFinite(avg.sum) == false) {
sum += avg.sum;
} else if (Double.isFinite(sum)) {
double corrected = avg.sum - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalAvg(getName(), sum, count, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,23 @@ public InternalStats doReduce(List<InternalAggregation> aggregations, ReduceCont
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
double sum = 0;
double compensation = 0;
for (InternalAggregation aggregation : aggregations) {
InternalStats stats = (InternalStats) aggregation;
count += stats.getCount();
min = Math.min(min, stats.getMin());
max = Math.max(max, stats.getMax());
sum += stats.getSum();
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double value = stats.getSum();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalStats(name, count, sum, min, max, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class StatsAggregator extends NumericMetricsAggregator.MultiValue {

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DoubleArray mins;
DoubleArray maxes;

Expand All @@ -59,6 +60,7 @@ public StatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueF
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
Expand Down Expand Up @@ -88,6 +90,7 @@ public void collect(int doc, long bucket) throws IOException {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
Expand All @@ -97,16 +100,28 @@ public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
mins.set(bucket, min);
maxes.set(bucket, max);
}
Expand Down Expand Up @@ -164,6 +179,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, maxes, mins, sums);
Releasables.close(counts, maxes, mins, sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DoubleArray mins;
DoubleArray maxes;
DoubleArray sumOfSqrs;
DoubleArray compensationOfSqrs;

public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter,
SearchContext context, Aggregator parent, double sigma, List<PipelineAggregator> pipelineAggregators,
Expand All @@ -65,11 +67,13 @@ public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, D
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
mins = bigArrays.newDoubleArray(1, false);
mins.fill(0, mins.size(), Double.POSITIVE_INFINITY);
maxes = bigArrays.newDoubleArray(1, false);
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
sumOfSqrs = bigArrays.newDoubleArray(1, true);
compensationOfSqrs = bigArrays.newDoubleArray(1, true);
}
}

Expand All @@ -95,29 +99,52 @@ public void collect(int doc, long bucket) throws IOException {
final long overSize = BigArrays.overSize(bucket + 1);
counts = bigArrays.resize(counts, overSize);
sums = bigArrays.resize(sums, overSize);
compensations = bigArrays.resize(compensations, overSize);
mins = bigArrays.resize(mins, overSize);
maxes = bigArrays.resize(maxes, overSize);
sumOfSqrs = bigArrays.resize(sumOfSqrs, overSize);
compensationOfSqrs = bigArrays.resize(compensationOfSqrs, overSize);
mins.fill(from, overSize, Double.POSITIVE_INFINITY);
maxes.fill(from, overSize, Double.NEGATIVE_INFINITY);
}

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
counts.increment(bucket, valuesCount);
double sum = 0;
double sumOfSqr = 0;
double min = mins.get(bucket);
double max = maxes.get(bucket);
// Compute the sum and sum of squires for double values with Kahan summation algorithm
// which is more accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
double sumOfSqr = sumOfSqrs.get(bucket);
double compensationOfSqr = compensationOfSqrs.get(bucket);
for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
sum += value;
sumOfSqr += value * value;
if (Double.isFinite(value) == false) {
sum += value;
sumOfSqr += value * value;
} else {
if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
if (Double.isFinite(sumOfSqr)) {
double correctedOfSqr = value * value - compensationOfSqr;
double newSumOfSqr = sumOfSqr + correctedOfSqr;
compensationOfSqr = (newSumOfSqr - sumOfSqr) - correctedOfSqr;
sumOfSqr = newSumOfSqr;
}
}
min = Math.min(min, value);
max = Math.max(max, value);
}
sums.increment(bucket, sum);
sumOfSqrs.increment(bucket, sumOfSqr);
sums.set(bucket, sum);
compensations.set(bucket, compensation);
sumOfSqrs.set(bucket, sumOfSqr);
compensationOfSqrs.set(bucket, compensationOfSqr);
mins.set(bucket, min);
maxes.set(bucket, max);
}
Expand Down Expand Up @@ -196,6 +223,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, maxes, mins, sumOfSqrs, sums);
Releasables.close(counts, maxes, mins, sumOfSqrs, compensationOfSqrs, sums, compensations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static Metrics resolve(String name) {
private final double sigma;

public InternalExtendedStats(String name, long count, double sum, double min, double max, double sumOfSqrs, double sigma,
DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
super(name, count, sum, min, max, formatter, pipelineAggregators, metaData);
this.sumOfSqrs = sumOfSqrs;
this.sigma = sigma;
Expand Down Expand Up @@ -142,16 +142,25 @@ public String getStdDeviationBoundAsString(Bounds bound) {
@Override
public InternalExtendedStats doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
double sumOfSqrs = 0;
double compensationOfSqrs = 0;
for (InternalAggregation aggregation : aggregations) {
InternalExtendedStats stats = (InternalExtendedStats) aggregation;
if (stats.sigma != sigma) {
throw new IllegalStateException("Cannot reduce other stats aggregations that have a different sigma");
}
sumOfSqrs += stats.getSumOfSquares();
double value = stats.getSumOfSquares();
if (Double.isFinite(value) == false) {
sumOfSqrs += value;
} else if (Double.isFinite(sumOfSqrs)) {
double correctedOfSqrs = value - compensationOfSqrs;
double newSumOfSqrs = sumOfSqrs + correctedOfSqrs;
compensationOfSqrs = (newSumOfSqrs - sumOfSqrs) - correctedOfSqrs;
sumOfSqrs = newSumOfSqrs;
}
}
final InternalStats stats = super.doReduce(aggregations, reduceContext);
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, sigma,
format, pipelineAggregators(), getMetaData());
format, pipelineAggregators(), getMetaData());
}

static class Fields {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class InternalSum extends InternalNumericMetricsAggregation.SingleValue i
private final double sum;

public InternalSum(String name, double sum, DocValueFormat formatter, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) {
Map<String, Object> metaData) {
super(name, pipelineAggregators, metaData);
this.sum = sum;
this.format = formatter;
Expand Down Expand Up @@ -73,9 +73,20 @@ public double getValue() {

@Override
public InternalSum doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = 0;
double compensation = 0;
for (InternalAggregation aggregation : aggregations) {
sum += ((InternalSum) aggregation).sum;
double value = ((InternalSum) aggregation).sum;
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalSum(name, sum, format, pipelineAggregators(), getMetaData());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
private final DocValueFormat format;

private DoubleArray sums;
private DoubleArray compensations;

SumAggregator(String name, ValuesSource.Numeric valuesSource, DocValueFormat formatter, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
Expand All @@ -51,6 +52,7 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue {
this.format = formatter;
if (valuesSource != null) {
sums = context.bigArrays().newDoubleArray(1, true);
compensations = context.bigArrays().newDoubleArray(1, true);
}
}

Expand All @@ -71,13 +73,27 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
@Override
public void collect(int doc, long bucket) throws IOException {
sums = bigArrays.grow(sums, bucket + 1);
compensations = bigArrays.grow(compensations, bucket + 1);

if (values.advanceExact(doc)) {
final int valuesCount = values.docValueCount();
double sum = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
for (int i = 0; i < valuesCount; i++) {
sum += values.nextValue();
double value = values.nextValue();
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
}
sums.increment(bucket, sum);
compensations.set(bucket, compensation);
sums.set(bucket, sum);
}
}
};
Expand Down Expand Up @@ -106,6 +122,6 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(sums);
Releasables.close(sums, compensations);
}
}
Loading

0 comments on commit 82d782b

Please sign in to comment.