Skip to content

Commit

Permalink
Flatten decimal accumulator
Browse files Browse the repository at this point in the history
This improves decimal aggregations performance

BEFORE:
Benchmark                                                  (function)  (groupCount)  (type)  Mode  Cnt       Score       Error  Units
BenchmarkDecimalAggregation.benchmark                             sum            10    LONG  avgt   10      11,782 ±     0,511  ns/op
BenchmarkDecimalAggregation.benchmark                             sum         10000    LONG  avgt   10      20,122 ±     0,773  ns/op
BenchmarkDecimalAggregation.benchmarkEvaluateFinal                sum         10000    LONG  avgt   10  727254,852 ± 21119,807  ns/op
BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate         sum            10    LONG  avgt   10      11,628 ±     0,175  ns/op
BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate         sum         10000    LONG  avgt   10      20,417 ±     0,594  ns/op

AFTER
Benchmark                                                  (function)  (groupCount)  (type)  Mode  Cnt       Score       Error  Units
BenchmarkDecimalAggregation.benchmark                             sum            10    LONG  avgt   10       5,488 ±     0,466  ns/op
BenchmarkDecimalAggregation.benchmark                             sum         10000    LONG  avgt   10       6,471 ±     0,664  ns/op
BenchmarkDecimalAggregation.benchmarkEvaluateFinal                sum         10000    LONG  avgt   10  552729,926 ± 23429,821  ns/op
BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate         sum            10    LONG  avgt   10       5,236 ±     0,279  ns/op
BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate         sum         10000    LONG  avgt   10       6,462 ±     0,323  ns/opfixup
  • Loading branch information
sopel39 committed Oct 18, 2021
1 parent edef782 commit b524793
Show file tree
Hide file tree
Showing 13 changed files with 432 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.slice.Slice;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
Expand All @@ -32,10 +31,8 @@
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.UnscaledDecimal128Arithmetic;

import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
Expand All @@ -46,18 +43,20 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.metadata.FunctionKind.AGGREGATE;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static io.trino.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.trino.spi.type.Decimals.MAX_PRECISION;
import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION;
import static io.trino.spi.type.Decimals.writeBigDecimal;
import static io.trino.spi.type.Decimals.writeShortDecimal;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.SIGN_LONG_MASK;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.addWithOverflow;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.util.Reflection.methodHandle;
import static java.math.BigDecimal.ROUND_HALF_UP;
Expand All @@ -67,10 +66,6 @@ public class DecimalAverageAggregation
{
public static final DecimalAverageAggregation DECIMAL_AVERAGE_AGGREGATION = new DecimalAverageAggregation();

// Constant references for short/long decimal types for use in operations that only manipulate unscaled values
private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_PRECISION, 0);
private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION, 0);

private static final String NAME = "avg";
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputShortDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class);
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputLongDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class);
Expand Down Expand Up @@ -156,25 +151,44 @@ public static void inputShortDecimal(LongDecimalWithOverflowAndLongState state,
{
state.addLong(1); // row counter

Slice sum = state.getLongDecimal();
if (sum == null) {
sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
state.setLongDecimal(sum);
state.setNotNull();

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long rightLow = block.getLong(position, 0);
long rightHigh = 0;
if (rightLow < 0) {
rightLow = -rightLow;
rightHigh = SIGN_LONG_MASK;
}
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, UnscaledDecimal128Arithmetic.unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(block, position)), sum);

long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
rightLow,
rightHigh,
decimal,
offset);
state.addOverflow(overflow);
}

public static void inputLongDecimal(LongDecimalWithOverflowAndLongState state, Block block, int position)
{
state.addLong(1); // row counter

Slice sum = state.getLongDecimal();
if (sum == null) {
sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
state.setLongDecimal(sum);
}
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, LONG_DECIMAL_TYPE.getSlice(block, position), sum);
state.setNotNull();

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
block.getLong(position, 0),
block.getLong(position, SIZE_OF_LONG),
decimal,
offset);
state.addOverflow(overflow);
}

Expand All @@ -184,12 +198,25 @@ public static void combine(LongDecimalWithOverflowAndLongState state, LongDecima

long overflow = otherState.getOverflow();

Slice sum = state.getLongDecimal();
if (sum == null) {
state.setLongDecimal(otherState.getLongDecimal());
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long[] otherDecimal = otherState.getDecimalArray();
int otherOffset = otherState.getDecimalArrayOffset();

if (state.isNotNull()) {
overflow += addWithOverflow(
decimal[offset],
decimal[offset + 1],
otherDecimal[otherOffset],
otherDecimal[otherOffset + 1],
decimal,
offset);
}
else {
overflow += UnscaledDecimal128Arithmetic.addWithOverflow(sum, otherState.getLongDecimal(), sum);
state.setNotNull();
decimal[offset] = otherDecimal[otherOffset];
decimal[offset + 1] = otherDecimal[otherOffset + 1];
}

state.addOverflow(overflow);
Expand Down Expand Up @@ -218,7 +245,9 @@ public static void outputLongDecimal(DecimalType type, LongDecimalWithOverflowAn
@VisibleForTesting
public static BigDecimal average(LongDecimalWithOverflowAndLongState state, DecimalType type)
{
BigDecimal sum = new BigDecimal(Decimals.decodeUnscaledValue(state.getLongDecimal()), type.getScale());
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
BigDecimal sum = new BigDecimal(unscaledDecimalToBigInteger(decimal[offset], decimal[offset + 1]), type.getScale());
BigDecimal count = BigDecimal.valueOf(state.getLong());

long overflow = state.getOverflow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.slice.Slice;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
Expand All @@ -33,7 +32,6 @@
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.UnscaledDecimal128Arithmetic;

import java.lang.invoke.MethodHandle;
import java.util.List;
Expand All @@ -42,29 +40,25 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.metadata.FunctionKind.AGGREGATE;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static io.trino.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.trino.spi.type.Decimals.MAX_PRECISION;
import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION;
import static io.trino.spi.type.TypeSignatureParameter.numericParameter;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.SIGN_LONG_MASK;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.addWithOverflow;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.throwIfOverflows;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.throwOverflowException;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimal;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.util.Reflection.methodHandle;

public class DecimalSumAggregation
extends SqlAggregationFunction
{
// Constant references for short/long decimal types for use in operations that only manipulate unscaled values
private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_PRECISION, 0);
private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION, 0);

public static final DecimalSumAggregation DECIMAL_SUM_AGGREGATION = new DecimalSumAggregation();
private static final String NAME = "sum";
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputShortDecimal", LongDecimalWithOverflowState.class, Block.class, int.class);
Expand Down Expand Up @@ -142,53 +136,93 @@ private static List<ParameterMetadata> createInputParameterMetadata(Type type)

public static void inputShortDecimal(LongDecimalWithOverflowState state, Block block, int position)
{
Slice sum = state.getLongDecimal();
if (sum == null) {
sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
state.setLongDecimal(sum);
state.setNotNull();

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long rightLow = block.getLong(position, 0);
long rightHigh = 0;
if (rightLow < 0) {
rightLow = -rightLow;
rightHigh = SIGN_LONG_MASK;
}
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(block, position)), sum);

long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
rightLow,
rightHigh,
decimal,
offset);
state.addOverflow(overflow);
}

public static void inputLongDecimal(LongDecimalWithOverflowState state, Block block, int position)
{
Slice sum = state.getLongDecimal();
if (sum == null) {
sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
state.setLongDecimal(sum);
}
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, LONG_DECIMAL_TYPE.getSlice(block, position), sum);
state.setNotNull();

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
block.getLong(position, 0),
block.getLong(position, SIZE_OF_LONG),
decimal,
offset);
state.addOverflow(overflow);
}

public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOverflowState otherState)
{
Slice sum = state.getLongDecimal();
long overflow = otherState.getOverflow();

if (sum == null) {
state.setLongDecimal(otherState.getLongDecimal());
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long[] otherDecimal = otherState.getDecimalArray();
int otherOffset = otherState.getDecimalArrayOffset();

if (state.isNotNull()) {
overflow += addWithOverflow(
decimal[offset],
decimal[offset + 1],
otherDecimal[otherOffset],
otherDecimal[otherOffset + 1],
decimal,
offset);
}
else {
overflow += UnscaledDecimal128Arithmetic.addWithOverflow(sum, otherState.getLongDecimal(), sum);
state.setNotNull();
decimal[offset] = otherDecimal[otherOffset];
decimal[offset + 1] = otherDecimal[otherOffset + 1];
}

state.addOverflow(overflow);
}

public static void outputLongDecimal(LongDecimalWithOverflowState state, BlockBuilder out)
{
Slice decimal = state.getLongDecimal();
if (decimal == null) {
out.appendNull();
}
else {
if (state.isNotNull()) {
if (state.getOverflow() != 0) {
throwOverflowException();
}
throwIfOverflows(decimal);
LONG_DECIMAL_TYPE.writeSlice(out, decimal);

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long rawLow = decimal[offset];
long rawHigh = decimal[offset + 1];

throwIfOverflows(rawLow, rawHigh);
out.writeLong(rawLow);
out.writeLong(rawHigh);
out.closeEntry();
}
else {
out.appendNull();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void addLong(long value)
@Override
public long getEstimatedSize()
{
return INSTANCE_SIZE + unscaledDecimals.sizeOf() + (numberOfElements * SingleLongDecimalWithOverflowAndLongState.SIZE) + (overflows == null ? 0 : overflows.sizeOf());
return INSTANCE_SIZE + isNotNull.sizeOf() + unscaledDecimals.sizeOf() + (overflows == null ? 0 : overflows.sizeOf());
}
}

Expand Down Expand Up @@ -112,9 +112,6 @@ public void addLong(long value)
@Override
public long getEstimatedSize()
{
if (getLongDecimal() == null) {
return INSTANCE_SIZE;
}
return INSTANCE_SIZE + SIZE;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ public Type getSerializedType()
@Override
public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder out)
{
Slice decimal = state.getLongDecimal();
if (decimal == null) {
out.appendNull();
}
else {
if (state.isNotNull()) {
long count = state.getLong();
long overflow = state.getOverflow();
VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal.getLong(0), decimal.getLong(Long.BYTES)));
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal[offset], decimal[offset + 1]));
}
else {
out.appendNull();
}
}

Expand All @@ -59,11 +60,14 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt

long count = slice.getLong(0);
long overflow = slice.getLong(Long.BYTES);
Slice decimal = Slices.wrappedLongArray(slice.getLong(Long.BYTES * 2), slice.getLong(Long.BYTES * 3));

state.setLong(count);
state.setOverflow(overflow);
state.setLongDecimal(decimal);
state.setNotNull();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
decimal[offset] = slice.getLong(Long.BYTES * 2);
decimal[offset + 1] = slice.getLong(Long.BYTES * 3);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@
*/
package io.trino.operator.aggregation.state;

import io.airlift.slice.Slice;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateMetadata;

@AccumulatorStateMetadata(stateFactoryClass = LongDecimalWithOverflowStateFactory.class, stateSerializerClass = LongDecimalWithOverflowStateSerializer.class)
public interface LongDecimalWithOverflowState
extends AccumulatorState
{
Slice getLongDecimal();
boolean isNotNull();

void setLongDecimal(Slice unscaledDecimal);
void setNotNull();

long[] getDecimalArray();

int getDecimalArrayOffset();

long getOverflow();

Expand Down
Loading

0 comments on commit b524793

Please sign in to comment.