From b524793f8427125de19c3ccaf07c885c7d0be8fa Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Mon, 18 Oct 2021 12:58:46 +0200 Subject: [PATCH] Flatten decimal accumulator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This improves decimal aggregations performance BEFORE: Benchmark (function) (groupCount) (type) Mode Cnt Score Error Units BenchmarkDecimalAggregation.benchmark sum 10 LONG avgt 10 11,782 ± 0,511 ns/op BenchmarkDecimalAggregation.benchmark sum 10000 LONG avgt 10 20,122 ± 0,773 ns/op BenchmarkDecimalAggregation.benchmarkEvaluateFinal sum 10000 LONG avgt 10 727254,852 ± 21119,807 ns/op BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate sum 10 LONG avgt 10 11,628 ± 0,175 ns/op BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate sum 10000 LONG avgt 10 20,417 ± 0,594 ns/op AFTER Benchmark (function) (groupCount) (type) Mode Cnt Score Error Units BenchmarkDecimalAggregation.benchmark sum 10 LONG avgt 10 5,488 ± 0,466 ns/op BenchmarkDecimalAggregation.benchmark sum 10000 LONG avgt 10 6,471 ± 0,664 ns/op BenchmarkDecimalAggregation.benchmarkEvaluateFinal sum 10000 LONG avgt 10 552729,926 ± 23429,821 ns/op BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate sum 10 LONG avgt 10 5,236 ± 0,279 ns/op BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate sum 10000 LONG avgt 10 6,462 ± 0,323 ns/opfixup --- .../DecimalAverageAggregation.java | 79 ++++++--- .../aggregation/DecimalSumAggregation.java | 96 +++++++---- ...ecimalWithOverflowAndLongStateFactory.java | 5 +- ...malWithOverflowAndLongStateSerializer.java | 20 ++- .../state/LongDecimalWithOverflowState.java | 9 +- .../LongDecimalWithOverflowStateFactory.java | 72 ++++++--- ...ongDecimalWithOverflowStateSerializer.java | 18 ++- .../TestDecimalAverageAggregation.java | 27 +++- .../TestDecimalSumAggregation.java | 26 ++- .../type/UnscaledDecimal128Arithmetic.java | 152 +++++++++++++++++- .../TestUnscaledDecimal128Arithmetic.java | 35 +++- .../java/io/trino/array/BooleanBigArray.java | 5 + .../java/io/trino/array/LongBigArray.java | 10 ++ 13 files changed, 432 insertions(+), 122 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java index ec9219592512..4d71195f41d3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java @@ -16,7 +16,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.airlift.bytecode.DynamicClassLoader; -import io.airlift.slice.Slice; import io.trino.metadata.AggregationFunctionMetadata; import io.trino.metadata.FunctionArgumentDefinition; import io.trino.metadata.FunctionBinding; @@ -32,10 +31,8 @@ import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.type.DecimalType; -import io.trino.spi.type.Decimals; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; -import io.trino.spi.type.UnscaledDecimal128Arithmetic; import java.lang.invoke.MethodHandle; import java.math.BigDecimal; @@ -46,18 +43,20 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.trino.metadata.FunctionKind.AGGREGATE; import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata; import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static io.trino.operator.aggregation.AggregationUtils.generateAggregationName; -import static io.trino.spi.type.Decimals.MAX_PRECISION; -import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION; import static io.trino.spi.type.Decimals.writeBigDecimal; import static io.trino.spi.type.Decimals.writeShortDecimal; import static io.trino.spi.type.TypeSignatureParameter.typeVariable; +import static io.trino.spi.type.UnscaledDecimal128Arithmetic.SIGN_LONG_MASK; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH; +import static io.trino.spi.type.UnscaledDecimal128Arithmetic.addWithOverflow; +import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.util.Reflection.methodHandle; import static java.math.BigDecimal.ROUND_HALF_UP; @@ -67,10 +66,6 @@ public class DecimalAverageAggregation { public static final DecimalAverageAggregation DECIMAL_AVERAGE_AGGREGATION = new DecimalAverageAggregation(); - // Constant references for short/long decimal types for use in operations that only manipulate unscaled values - private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_PRECISION, 0); - private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION, 0); - private static final String NAME = "avg"; private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputShortDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class); private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputLongDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class); @@ -156,12 +151,25 @@ public static void inputShortDecimal(LongDecimalWithOverflowAndLongState state, { state.addLong(1); // row counter - Slice sum = state.getLongDecimal(); - if (sum == null) { - sum = UnscaledDecimal128Arithmetic.unscaledDecimal(); - state.setLongDecimal(sum); + state.setNotNull(); + + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + long rightLow = block.getLong(position, 0); + long rightHigh = 0; + if (rightLow < 0) { + rightLow = -rightLow; + rightHigh = SIGN_LONG_MASK; } - long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, UnscaledDecimal128Arithmetic.unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(block, position)), sum); + + long overflow = addWithOverflow( + decimal[offset], + decimal[offset + 1], + rightLow, + rightHigh, + decimal, + offset); state.addOverflow(overflow); } @@ -169,12 +177,18 @@ public static void inputLongDecimal(LongDecimalWithOverflowAndLongState state, B { state.addLong(1); // row counter - Slice sum = state.getLongDecimal(); - if (sum == null) { - sum = UnscaledDecimal128Arithmetic.unscaledDecimal(); - state.setLongDecimal(sum); - } - long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, LONG_DECIMAL_TYPE.getSlice(block, position), sum); + state.setNotNull(); + + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + long overflow = addWithOverflow( + decimal[offset], + decimal[offset + 1], + block.getLong(position, 0), + block.getLong(position, SIZE_OF_LONG), + decimal, + offset); state.addOverflow(overflow); } @@ -184,12 +198,25 @@ public static void combine(LongDecimalWithOverflowAndLongState state, LongDecima long overflow = otherState.getOverflow(); - Slice sum = state.getLongDecimal(); - if (sum == null) { - state.setLongDecimal(otherState.getLongDecimal()); + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + long[] otherDecimal = otherState.getDecimalArray(); + int otherOffset = otherState.getDecimalArrayOffset(); + + if (state.isNotNull()) { + overflow += addWithOverflow( + decimal[offset], + decimal[offset + 1], + otherDecimal[otherOffset], + otherDecimal[otherOffset + 1], + decimal, + offset); } else { - overflow += UnscaledDecimal128Arithmetic.addWithOverflow(sum, otherState.getLongDecimal(), sum); + state.setNotNull(); + decimal[offset] = otherDecimal[otherOffset]; + decimal[offset + 1] = otherDecimal[otherOffset + 1]; } state.addOverflow(overflow); @@ -218,7 +245,9 @@ public static void outputLongDecimal(DecimalType type, LongDecimalWithOverflowAn @VisibleForTesting public static BigDecimal average(LongDecimalWithOverflowAndLongState state, DecimalType type) { - BigDecimal sum = new BigDecimal(Decimals.decodeUnscaledValue(state.getLongDecimal()), type.getScale()); + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + BigDecimal sum = new BigDecimal(unscaledDecimalToBigInteger(decimal[offset], decimal[offset + 1]), type.getScale()); BigDecimal count = BigDecimal.valueOf(state.getLong()); long overflow = state.getOverflow(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java index 2780f2d9b879..0d878667c76e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java @@ -15,7 +15,6 @@ import com.google.common.collect.ImmutableList; import io.airlift.bytecode.DynamicClassLoader; -import io.airlift.slice.Slice; import io.trino.metadata.AggregationFunctionMetadata; import io.trino.metadata.FunctionArgumentDefinition; import io.trino.metadata.FunctionBinding; @@ -33,7 +32,6 @@ import io.trino.spi.type.DecimalType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; -import io.trino.spi.type.UnscaledDecimal128Arithmetic; import java.lang.invoke.MethodHandle; import java.util.List; @@ -42,29 +40,25 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.trino.metadata.FunctionKind.AGGREGATE; import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata; import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static io.trino.operator.aggregation.AggregationUtils.generateAggregationName; -import static io.trino.spi.type.Decimals.MAX_PRECISION; -import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION; import static io.trino.spi.type.TypeSignatureParameter.numericParameter; import static io.trino.spi.type.TypeSignatureParameter.typeVariable; +import static io.trino.spi.type.UnscaledDecimal128Arithmetic.SIGN_LONG_MASK; +import static io.trino.spi.type.UnscaledDecimal128Arithmetic.addWithOverflow; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.throwIfOverflows; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.throwOverflowException; -import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimal; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.util.Reflection.methodHandle; public class DecimalSumAggregation extends SqlAggregationFunction { - // Constant references for short/long decimal types for use in operations that only manipulate unscaled values - private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_PRECISION, 0); - private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION, 0); - public static final DecimalSumAggregation DECIMAL_SUM_AGGREGATION = new DecimalSumAggregation(); private static final String NAME = "sum"; private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputShortDecimal", LongDecimalWithOverflowState.class, Block.class, int.class); @@ -142,36 +136,68 @@ private static List createInputParameterMetadata(Type type) public static void inputShortDecimal(LongDecimalWithOverflowState state, Block block, int position) { - Slice sum = state.getLongDecimal(); - if (sum == null) { - sum = UnscaledDecimal128Arithmetic.unscaledDecimal(); - state.setLongDecimal(sum); + state.setNotNull(); + + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + long rightLow = block.getLong(position, 0); + long rightHigh = 0; + if (rightLow < 0) { + rightLow = -rightLow; + rightHigh = SIGN_LONG_MASK; } - long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(block, position)), sum); + + long overflow = addWithOverflow( + decimal[offset], + decimal[offset + 1], + rightLow, + rightHigh, + decimal, + offset); state.addOverflow(overflow); } public static void inputLongDecimal(LongDecimalWithOverflowState state, Block block, int position) { - Slice sum = state.getLongDecimal(); - if (sum == null) { - sum = UnscaledDecimal128Arithmetic.unscaledDecimal(); - state.setLongDecimal(sum); - } - long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, LONG_DECIMAL_TYPE.getSlice(block, position), sum); + state.setNotNull(); + + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + long overflow = addWithOverflow( + decimal[offset], + decimal[offset + 1], + block.getLong(position, 0), + block.getLong(position, SIZE_OF_LONG), + decimal, + offset); state.addOverflow(overflow); } public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOverflowState otherState) { - Slice sum = state.getLongDecimal(); long overflow = otherState.getOverflow(); - if (sum == null) { - state.setLongDecimal(otherState.getLongDecimal()); + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + long[] otherDecimal = otherState.getDecimalArray(); + int otherOffset = otherState.getDecimalArrayOffset(); + + if (state.isNotNull()) { + overflow += addWithOverflow( + decimal[offset], + decimal[offset + 1], + otherDecimal[otherOffset], + otherDecimal[otherOffset + 1], + decimal, + offset); } else { - overflow += UnscaledDecimal128Arithmetic.addWithOverflow(sum, otherState.getLongDecimal(), sum); + state.setNotNull(); + decimal[offset] = otherDecimal[otherOffset]; + decimal[offset + 1] = otherDecimal[otherOffset + 1]; } state.addOverflow(overflow); @@ -179,16 +205,24 @@ public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOv public static void outputLongDecimal(LongDecimalWithOverflowState state, BlockBuilder out) { - Slice decimal = state.getLongDecimal(); - if (decimal == null) { - out.appendNull(); - } - else { + if (state.isNotNull()) { if (state.getOverflow() != 0) { throwOverflowException(); } - throwIfOverflows(decimal); - LONG_DECIMAL_TYPE.writeSlice(out, decimal); + + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + long rawLow = decimal[offset]; + long rawHigh = decimal[offset + 1]; + + throwIfOverflows(rawLow, rawHigh); + out.writeLong(rawLow); + out.writeLong(rawHigh); + out.closeEntry(); + } + else { + out.appendNull(); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java index 8ad247c12b32..e43dc0a0c049 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java @@ -79,7 +79,7 @@ public void addLong(long value) @Override public long getEstimatedSize() { - return INSTANCE_SIZE + unscaledDecimals.sizeOf() + (numberOfElements * SingleLongDecimalWithOverflowAndLongState.SIZE) + (overflows == null ? 0 : overflows.sizeOf()); + return INSTANCE_SIZE + isNotNull.sizeOf() + unscaledDecimals.sizeOf() + (overflows == null ? 0 : overflows.sizeOf()); } } @@ -112,9 +112,6 @@ public void addLong(long value) @Override public long getEstimatedSize() { - if (getLongDecimal() == null) { - return INSTANCE_SIZE; - } return INSTANCE_SIZE + SIZE; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java index e609b1c6e640..cc732b16d851 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java @@ -37,14 +37,15 @@ public Type getSerializedType() @Override public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder out) { - Slice decimal = state.getLongDecimal(); - if (decimal == null) { - out.appendNull(); - } - else { + if (state.isNotNull()) { long count = state.getLong(); long overflow = state.getOverflow(); - VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal.getLong(0), decimal.getLong(Long.BYTES))); + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal[offset], decimal[offset + 1])); + } + else { + out.appendNull(); } } @@ -59,11 +60,14 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt long count = slice.getLong(0); long overflow = slice.getLong(Long.BYTES); - Slice decimal = Slices.wrappedLongArray(slice.getLong(Long.BYTES * 2), slice.getLong(Long.BYTES * 3)); state.setLong(count); state.setOverflow(overflow); - state.setLongDecimal(decimal); + state.setNotNull(); + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + decimal[offset] = slice.getLong(Long.BYTES * 2); + decimal[offset + 1] = slice.getLong(Long.BYTES * 3); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowState.java index bd55bfe30923..bae8584c2a0e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowState.java @@ -13,7 +13,6 @@ */ package io.trino.operator.aggregation.state; -import io.airlift.slice.Slice; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -21,9 +20,13 @@ public interface LongDecimalWithOverflowState extends AccumulatorState { - Slice getLongDecimal(); + boolean isNotNull(); - void setLongDecimal(Slice unscaledDecimal); + void setNotNull(); + + long[] getDecimalArray(); + + int getDecimalArrayOffset(); long getOverflow(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java index bba66a865f26..f3511817e0f8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java @@ -13,16 +13,14 @@ */ package io.trino.operator.aggregation.state; -import io.airlift.slice.Slice; +import io.trino.array.BooleanBigArray; import io.trino.array.LongBigArray; -import io.trino.array.ObjectBigArray; import io.trino.spi.function.AccumulatorStateFactory; import org.openjdk.jol.info.ClassLayout; import javax.annotation.Nullable; -import static io.trino.spi.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH; -import static java.util.Objects.requireNonNull; +import static io.airlift.slice.SizeOf.sizeOf; public class LongDecimalWithOverflowStateFactory implements AccumulatorStateFactory @@ -56,34 +54,46 @@ public static class GroupedLongDecimalWithOverflowState implements LongDecimalWithOverflowState { private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedLongDecimalWithOverflowState.class).instanceSize(); - protected final ObjectBigArray unscaledDecimals = new ObjectBigArray<>(); + protected final BooleanBigArray isNotNull = new BooleanBigArray(); + /** + * Stores 128-bit decimals as pairs of longs + */ + protected final LongBigArray unscaledDecimals = new LongBigArray(); @Nullable protected LongBigArray overflows; // lazily initialized on the first overflow - protected long numberOfElements; @Override public void ensureCapacity(long size) { - unscaledDecimals.ensureCapacity(size); + isNotNull.ensureCapacity(size); + unscaledDecimals.ensureCapacity(size * 2); if (overflows != null) { overflows.ensureCapacity(size); } } @Override - public Slice getLongDecimal() + public boolean isNotNull() { - return unscaledDecimals.get(getGroupId()); + return isNotNull.get(getGroupId()); } @Override - public void setLongDecimal(Slice value) + public void setNotNull() { - requireNonNull(value, "value is null"); - boolean existed = unscaledDecimals.replace(getGroupId(), value); - if (!existed) { - numberOfElements++; - } + isNotNull.set(getGroupId(), true); + } + + @Override + public long[] getDecimalArray() + { + return unscaledDecimals.getSegment(getGroupId() * 2); + } + + @Override + public int getDecimalArrayOffset() + { + return unscaledDecimals.getOffset(getGroupId() * 2); } @Override @@ -105,7 +115,7 @@ public void setOverflow(long overflow) long groupId = getGroupId(); if (overflows == null) { overflows = new LongBigArray(); - overflows.ensureCapacity(unscaledDecimals.getCapacity()); + overflows.ensureCapacity(isNotNull.getCapacity()); } overflows.set(groupId, overflow); } @@ -117,7 +127,7 @@ public void addOverflow(long overflow) long groupId = getGroupId(); if (overflows == null) { overflows = new LongBigArray(); - overflows.ensureCapacity(unscaledDecimals.getCapacity()); + overflows.ensureCapacity(isNotNull.getCapacity()); } overflows.add(groupId, overflow); } @@ -126,7 +136,7 @@ public void addOverflow(long overflow) @Override public long getEstimatedSize() { - return INSTANCE_SIZE + unscaledDecimals.sizeOf() + (numberOfElements * SingleLongDecimalWithOverflowState.SIZE) + (overflows == null ? 0 : overflows.sizeOf()); + return INSTANCE_SIZE + isNotNull.sizeOf() + unscaledDecimals.sizeOf() + (overflows == null ? 0 : overflows.sizeOf()); } } @@ -134,21 +144,34 @@ public static class SingleLongDecimalWithOverflowState implements LongDecimalWithOverflowState { private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleLongDecimalWithOverflowState.class).instanceSize(); - public static final int SIZE = ClassLayout.parseClass(Slice.class).instanceSize() + UNSCALED_DECIMAL_128_SLICE_LENGTH; + protected static final int SIZE = (int) sizeOf(new long[2]); - protected Slice unscaledDecimal; + protected final long[] unscaledDecimal = new long[2]; + protected boolean isNotNull; protected long overflow; @Override - public Slice getLongDecimal() + public boolean isNotNull() + { + return isNotNull; + } + + @Override + public void setNotNull() + { + isNotNull = true; + } + + @Override + public long[] getDecimalArray() { return unscaledDecimal; } @Override - public void setLongDecimal(Slice unscaledDecimal) + public int getDecimalArrayOffset() { - this.unscaledDecimal = unscaledDecimal; + return 0; } @Override @@ -172,9 +195,6 @@ public void addOverflow(long overflow) @Override public long getEstimatedSize() { - if (getLongDecimal() == null) { - return INSTANCE_SIZE; - } return INSTANCE_SIZE + SIZE; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java index ef7c14a9d5e0..04d6fb1c5c24 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java @@ -37,13 +37,14 @@ public Type getSerializedType() @Override public void serialize(LongDecimalWithOverflowState state, BlockBuilder out) { - Slice decimal = state.getLongDecimal(); - if (decimal == null) { - out.appendNull(); + if (state.isNotNull()) { + long overflow = state.getOverflow(); + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + VARBINARY.writeSlice(out, Slices.wrappedLongArray(overflow, decimal[offset], decimal[offset + 1])); } else { - long overflow = state.getOverflow(); - VARBINARY.writeSlice(out, Slices.wrappedLongArray(overflow, decimal.getLong(0), decimal.getLong(Long.BYTES))); + out.appendNull(); } } @@ -57,10 +58,13 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowState sta } long overflow = slice.getLong(0); - Slice decimal = Slices.wrappedLongArray(slice.getLong(Long.BYTES), slice.getLong(Long.BYTES * 2)); state.setOverflow(overflow); - state.setLongDecimal(decimal); + state.setNotNull(); + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + decimal[offset] = slice.getLong(Long.BYTES); + decimal[offset + 1] = slice.getLong(Long.BYTES * 2); } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java index bdbb7f218107..e95156ab9735 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java @@ -13,8 +13,11 @@ */ package io.trino.operator.aggregation; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory; +import io.trino.operator.aggregation.state.LongDecimalWithOverflowState; import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.DecimalType; import io.trino.spi.type.UnscaledDecimal128Arithmetic; @@ -50,13 +53,13 @@ public void testOverflow() assertEquals(state.getLong(), 1); assertEquals(state.getOverflow(), 0); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(126))); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(126))); addToState(state, TWO.pow(126)); assertEquals(state.getLong(), 2); assertEquals(state.getOverflow(), 1); - assertEquals(state.getLongDecimal(), unscaledDecimal(0)); + assertEquals(getDecimalSlice(state), unscaledDecimal(0)); assertEquals(average(state, TYPE), new BigDecimal(TWO.pow(126))); } @@ -68,13 +71,13 @@ public void testUnderflow() assertEquals(state.getLong(), 1); assertEquals(state.getOverflow(), 0); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(126).negate())); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(126).negate())); addToState(state, TWO.pow(126).negate()); assertEquals(state.getLong(), 2); assertEquals(state.getOverflow(), -1); - assertEquals(UnscaledDecimal128Arithmetic.compare(state.getLongDecimal(), unscaledDecimal(0)), 0); + assertEquals(UnscaledDecimal128Arithmetic.compare(getDecimalSlice(state), unscaledDecimal(0)), 0); assertEquals(average(state, TYPE), new BigDecimal(TWO.pow(126).negate())); } @@ -87,14 +90,14 @@ public void testUnderflowAfterOverflow() addToState(state, TWO.pow(125)); assertEquals(state.getOverflow(), 1); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(125))); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(125))); addToState(state, TWO.pow(126).negate()); addToState(state, TWO.pow(126).negate()); addToState(state, TWO.pow(126).negate()); assertEquals(state.getOverflow(), 0); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(125).negate())); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(125).negate())); assertEquals(average(state, TYPE), new BigDecimal(TWO.pow(125).negate().divide(BigInteger.valueOf(6)))); } @@ -113,7 +116,7 @@ public void testCombineOverflow() DecimalAverageAggregation.combine(state, otherState); assertEquals(state.getLong(), 4); assertEquals(state.getOverflow(), 1); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(126))); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(126))); BigInteger expectedAverage = BigInteger.ZERO .add(TWO.pow(126)) @@ -139,7 +142,7 @@ public void testCombineUnderflow() DecimalAverageAggregation.combine(state, otherState); assertEquals(state.getLong(), 4); assertEquals(state.getOverflow(), -1); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(126).negate())); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(126).negate())); BigInteger expectedAverage = BigInteger.ZERO .add(TWO.pow(126)) @@ -163,4 +166,12 @@ private static void addToState(LongDecimalWithOverflowAndLongState state, BigInt DecimalAverageAggregation.inputLongDecimal(state, blockBuilder.build(), 0); } } + + private Slice getDecimalSlice(LongDecimalWithOverflowState state) + { + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + return Slices.wrappedLongArray(decimal[offset], decimal[offset + 1]); + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java index 39c5794318c0..93b60b3e0525 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java @@ -13,6 +13,8 @@ */ package io.trino.operator.aggregation; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.trino.operator.aggregation.state.LongDecimalWithOverflowState; import io.trino.operator.aggregation.state.LongDecimalWithOverflowStateFactory; import io.trino.spi.block.BlockBuilder; @@ -48,12 +50,12 @@ public void testOverflow() addToState(state, TWO.pow(126)); assertEquals(state.getOverflow(), 0); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(126))); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(126))); addToState(state, TWO.pow(126)); assertEquals(state.getOverflow(), 1); - assertEquals(state.getLongDecimal(), unscaledDecimal(0)); + assertEquals(getDecimalSlice(state), unscaledDecimal(0)); } @Test @@ -62,12 +64,12 @@ public void testUnderflow() addToState(state, TWO.pow(126).negate()); assertEquals(state.getOverflow(), 0); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(126).negate())); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(126).negate())); addToState(state, TWO.pow(126).negate()); assertEquals(state.getOverflow(), -1); - assertEquals(UnscaledDecimal128Arithmetic.compare(state.getLongDecimal(), unscaledDecimal(0)), 0); + assertEquals(UnscaledDecimal128Arithmetic.compare(getDecimalSlice(state), unscaledDecimal(0)), 0); } @Test @@ -78,14 +80,14 @@ public void testUnderflowAfterOverflow() addToState(state, TWO.pow(125)); assertEquals(state.getOverflow(), 1); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(125))); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(125))); addToState(state, TWO.pow(126).negate()); addToState(state, TWO.pow(126).negate()); addToState(state, TWO.pow(126).negate()); assertEquals(state.getOverflow(), 0); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(125).negate())); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(125).negate())); } @Test @@ -101,7 +103,7 @@ public void testCombineOverflow() DecimalSumAggregation.combine(state, otherState); assertEquals(state.getOverflow(), 1); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(126))); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(126))); } @Test @@ -117,7 +119,7 @@ public void testCombineUnderflow() DecimalSumAggregation.combine(state, otherState); assertEquals(state.getOverflow(), -1); - assertEquals(state.getLongDecimal(), unscaledDecimal(TWO.pow(126).negate())); + assertEquals(getDecimalSlice(state), unscaledDecimal(TWO.pow(126).negate())); } @Test(expectedExceptions = ArithmeticException.class) @@ -141,4 +143,12 @@ private static void addToState(LongDecimalWithOverflowState state, BigInteger va DecimalSumAggregation.inputLongDecimal(state, blockBuilder.build(), 0); } } + + private Slice getDecimalSlice(LongDecimalWithOverflowState state) + { + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); + + return Slices.wrappedLongArray(decimal[offset], decimal[offset + 1]); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/UnscaledDecimal128Arithmetic.java b/core/trino-spi/src/main/java/io/trino/spi/type/UnscaledDecimal128Arithmetic.java index 50067349a909..5c47e3cf5605 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/UnscaledDecimal128Arithmetic.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/UnscaledDecimal128Arithmetic.java @@ -45,13 +45,13 @@ public final class UnscaledDecimal128Arithmetic private static final int NUMBER_OF_INTS = 2 * NUMBER_OF_LONGS; public static final int UNSCALED_DECIMAL_128_SLICE_LENGTH = NUMBER_OF_LONGS * SIZE_OF_LONG; + public static final long SIGN_LONG_MASK = 1L << 63; private static final Slice[] POWERS_OF_TEN = new Slice[MAX_PRECISION]; private static final Slice[] POWERS_OF_FIVE = new Slice[MAX_PRECISION]; private static final int SIGN_LONG_INDEX = 1; private static final int SIGN_INT_INDEX = 3; - private static final long SIGN_LONG_MASK = 1L << 63; private static final int SIGN_INT_MASK = 1 << 31; private static final int SIGN_BYTE_MASK = 1 << 7; private static final long ALL_BITS_SET_64 = 0xFFFFFFFFFFFFFFFFL; @@ -184,6 +184,26 @@ public static BigInteger unscaledDecimalToBigInteger(Slice decimal) return new BigInteger(isNegative(decimal) ? -1 : 1, bytes); } + public static BigInteger unscaledDecimalToBigInteger(long rawLow, long rawHigh) + { + byte[] bytes = new byte[16]; + // convert to big-endian order + toByteArray(rawHigh, bytes, 0); + toByteArray(rawLow, bytes, 8); + bytes[0] &= ~SIGN_BYTE_MASK; + return new BigInteger(isNegative(rawHigh) ? -1 : 1, bytes); + } + + public static byte[] toByteArray(long value, byte[] result, int offset) + { + // copied from Guava Longs#toByteArray + for (int i = 7; i >= 0; i--) { + result[i + offset] = (byte) (value & 0xffL); + value >>= 8; + } + return result; + } + public static long unscaledDecimalToUnscaledLong(Slice decimal) { long low = getLong(decimal, 0); @@ -345,6 +365,21 @@ public static void add(Slice left, Slice right, Slice result) } } + // only visible for testing + static void add( + long leftRawLow, + long leftRawHigh, + long rightRawLow, + long rightRawHigh, + long[] result, + int resultOffset) + { + long overflow = addWithOverflow(leftRawLow, leftRawHigh, rightRawLow, rightRawHigh, result, resultOffset); + if (overflow != 0) { + throwOverflowException(); + } + } + /** * Instead of throwing overflow exception, this function returns: * 0 when there was no overflow @@ -378,6 +413,41 @@ else if (compare < 0) { return overflow; } + public static long addWithOverflow( + long leftRawLow, + long leftRawHigh, + long rightRawLow, + long rightRawHigh, + long[] result, + int resultOffset) + { + boolean leftNegative = isNegative(leftRawHigh); + boolean rightNegative = isNegative(rightRawHigh); + long overflow = 0; + if (leftNegative == rightNegative) { + // either both negative or both positive + overflow = addUnsignedReturnOverflow(leftRawLow, leftRawHigh, rightRawLow, rightRawHigh, result, resultOffset, leftNegative); + if (leftNegative) { + overflow = -overflow; + } + } + else { + int compare = compareAbsolute(leftRawLow, leftRawHigh, rightRawLow, rightRawHigh); + if (compare > 0) { + subtractUnsigned(leftRawLow, leftRawHigh, rightRawLow, rightRawHigh, result, resultOffset, leftNegative); + } + else if (compare < 0) { + subtractUnsigned(rightRawLow, rightRawHigh, leftRawLow, leftRawHigh, result, resultOffset, !leftNegative); + } + else { + // set to 0 + result[resultOffset] = 0; + result[resultOffset + 1] = 0; + } + } + return overflow; + } + public static Slice subtract(Slice left, Slice right) { Slice result = unscaledDecimal(); @@ -431,6 +501,28 @@ private static long addUnsignedReturnOverflow(Slice left, Slice right, Slice res return intermediateResult >>> 63; } + private static long addUnsignedReturnOverflow( + long leftRawLow, + long leftRawHigh, + long rightRawLow, + long rightRawHigh, + long[] result, + int resultOffset, + boolean resultNegative) + { + long leftHigh = unpackUnsignedLong(leftRawHigh); + long rightHigh = unpackUnsignedLong(rightRawHigh); + + long z0 = leftRawLow + rightRawLow; + int overflow = unsignedIsSmaller(z0, leftRawLow) ? 1 : 0; + + long intermediateResult = leftHigh + rightHigh + overflow; + long z1 = intermediateResult & (~SIGN_LONG_MASK); + pack(z0, z1, resultNegative, result, resultOffset); + + return intermediateResult >>> 63; + } + /** * This method ignores signs of the left and right and assumes that left is greater then right */ @@ -449,6 +541,22 @@ private static void subtractUnsigned(Slice left, Slice right, Slice result, bool pack(result, z0, z1, resultNegative); } + private static void subtractUnsigned( + long leftRawLow, + long leftRawHigh, + long rightRawLow, + long rightRawHigh, + long[] result, + int resultOffset, + boolean resultNegative) + { + long z0 = leftRawLow - rightRawLow; + int underflow = unsignedIsSmaller(leftRawLow, z0) ? 1 : 0; + long z1 = unpackUnsignedLong(leftRawHigh) - unpackUnsignedLong(rightRawHigh) - underflow; + + pack(z0, z1, resultNegative, result, resultOffset); + } + public static Slice multiply(Slice left, Slice right) { Slice result = unscaledDecimal(); @@ -918,6 +1026,25 @@ public static int compareAbsolute(Slice left, Slice right) return 0; } + public static int compareAbsolute( + long leftRawLow, + long leftRawHigh, + long rightRawLow, + long rightRawHigh) + { + long leftHigh = unpackUnsignedLong(leftRawHigh); + long rightHigh = unpackUnsignedLong(rightRawHigh); + if (leftHigh != rightHigh) { + return Long.compareUnsigned(leftHigh, rightHigh); + } + + if (leftRawLow != rightRawLow) { + return Long.compareUnsigned(leftRawLow, rightRawLow); + } + + return 0; + } + public static int compareUnsigned(long leftRawLow, long leftRawHigh, long rightRawLow, long rightRawHigh) { if (leftRawHigh != rightRawHigh) { @@ -1014,6 +1141,13 @@ public static void throwIfOverflows(Slice decimal) } } + public static void throwIfOverflows(long rawLow, long rawHigh) + { + if (exceedsOrEqualTenToThirtyEight(rawLow, rawHigh)) { + throwOverflowException(); + } + } + public static void throwIfOverflows(Slice value, int precision) { if (overflows(value, precision)) { @@ -1832,6 +1966,22 @@ private static boolean exceedsOrEqualTenToThirtyEight(Slice decimal) return low < 0 || low >= 0x098a224000000000L; } + private static boolean exceedsOrEqualTenToThirtyEight(long rawLow, long rawHigh) + { + // 10**38= + // i0 = 0(0), i1 = 160047680(98a22400), i2 = 1518781562(5a86c47a), i3 = 1262177448(4b3b4ca8) + // low = 0x98a2240000000000l, high = 0x4b3b4ca85a86c47al + long high = unpackUnsignedLong(rawHigh); + if (high >= 0 && high < 0x4b3b4ca85a86c47aL) { + return false; + } + if (high != 0x4b3b4ca85a86c47aL) { + return true; + } + + return rawLow < 0 || rawLow >= 0x098a224000000000L; + } + private static void reverse(final byte[] a) { final int length = a.length; diff --git a/core/trino-spi/src/test/java/io/trino/spi/type/TestUnscaledDecimal128Arithmetic.java b/core/trino-spi/src/test/java/io/trino/spi/type/TestUnscaledDecimal128Arithmetic.java index be28cdc214c5..69c480292a47 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/type/TestUnscaledDecimal128Arithmetic.java +++ b/core/trino-spi/src/test/java/io/trino/spi/type/TestUnscaledDecimal128Arithmetic.java @@ -43,6 +43,7 @@ import static io.trino.spi.type.UnscaledDecimal128Arithmetic.shiftLeftMultiPrecision; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.shiftRight; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.shiftRightMultiPrecision; +import static io.trino.spi.type.UnscaledDecimal128Arithmetic.throwIfOverflows; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.toUnscaledString; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimal; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger; @@ -169,6 +170,18 @@ public void testAdd() assertAdd(unscaledDecimal(1L << 32), unscaledDecimal(1L << 33), unscaledDecimal((1L << 32) + (1L << 33))); } + @Test + public void testThrowIfOverflows() + { + Slice decimal = add(unscaledDecimal(MAX_DECIMAL_UNSCALED_VALUE), unscaledDecimal(1)); + assertThatThrownBy(() -> throwIfOverflows(decimal)) + .isInstanceOf(ArithmeticException.class) + .hasMessage("Decimal overflow"); + assertThatThrownBy(() -> throwIfOverflows(decimal.getLong(0), decimal.getLong(SIZE_OF_LONG))) + .isInstanceOf(ArithmeticException.class) + .hasMessage("Decimal overflow"); + } + @Test public void testAddReturnOverflow() { @@ -555,12 +568,25 @@ public void testShiftLeft() private void assertAdd(Slice left, Slice right, Slice result) { assertEquals(add(left, right), result); + + // test with array based version of the method + long[] resultArray = new long[2]; + add( + left.getLong(0), + left.getLong(SIZE_OF_LONG), + right.getLong(0), + right.getLong(SIZE_OF_LONG), + resultArray, + 0); + assertEquals(unscaledDecimalToBigInteger(resultArray[0], resultArray[1]), unscaledDecimalToBigInteger(result)); } private void assertAddReturnOverflow(BigInteger left, BigInteger right) { Slice result = unscaledDecimal(); - long overflow = addWithOverflow(unscaledDecimal(left), unscaledDecimal(right), result); + Slice leftSlice = unscaledDecimal(left); + Slice rightSlice = unscaledDecimal(right); + long overflow = addWithOverflow(leftSlice, rightSlice, result); BigInteger actual = unscaledDecimalToBigInteger(result); BigInteger expected = left.add(right).remainder(TWO.pow(UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH * 8 - 1)); @@ -568,6 +594,13 @@ private void assertAddReturnOverflow(BigInteger left, BigInteger right) assertEquals(actual, expected); assertEquals(overflow, expectedOverflow.longValueExact()); + + // test with array based version of the method + long[] resultArray = new long[2]; + overflow = addWithOverflow(leftSlice.getLong(0), leftSlice.getLong(SIZE_OF_LONG), rightSlice.getLong(0), rightSlice.getLong(SIZE_OF_LONG), resultArray, 0); + + assertEquals(unscaledDecimalToBigInteger(resultArray[0], resultArray[1]), expected); + assertEquals(overflow, expectedOverflow.longValueExact()); } private static void assertUnscaledBigIntegerToDecimalOverflows(BigInteger value) diff --git a/lib/trino-array/src/main/java/io/trino/array/BooleanBigArray.java b/lib/trino-array/src/main/java/io/trino/array/BooleanBigArray.java index e4c7316b2859..1ebbee503925 100644 --- a/lib/trino-array/src/main/java/io/trino/array/BooleanBigArray.java +++ b/lib/trino-array/src/main/java/io/trino/array/BooleanBigArray.java @@ -94,6 +94,11 @@ public void ensureCapacity(long length) grow(length); } + public long getCapacity() + { + return capacity; + } + /** * Fills the entire big array with the specified value. */ diff --git a/lib/trino-array/src/main/java/io/trino/array/LongBigArray.java b/lib/trino-array/src/main/java/io/trino/array/LongBigArray.java index 3d45cd47d6e1..b10be1ad93f7 100644 --- a/lib/trino-array/src/main/java/io/trino/array/LongBigArray.java +++ b/lib/trino-array/src/main/java/io/trino/array/LongBigArray.java @@ -74,6 +74,16 @@ public long get(long index) return array[segment(index)][offset(index)]; } + public long[] getSegment(long index) + { + return array[segment(index)]; + } + + public int getOffset(long index) + { + return offset(index); + } + /** * Sets the element of this big array at specified index. *