Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant boolean state from LongDecimalWithOverflowAndLongState #16667

Merged
merged 1 commit into from
Mar 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ public static void inputShortDecimal(
{
state.addLong(1); // row counter

state.setNotNull();

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

Expand All @@ -89,8 +87,6 @@ public static void inputLongDecimal(
{
state.addLong(1); // row counter

state.setNotNull();

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

Expand Down Expand Up @@ -119,7 +115,7 @@ public static void combine(@AggregationState LongDecimalWithOverflowAndLongState
long[] otherDecimal = otherState.getDecimalArray();
int otherOffset = otherState.getDecimalArrayOffset();

if (state.isNotNull()) {
if (state.getLong() > 0) {
long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
Expand All @@ -130,7 +126,6 @@ public static void combine(@AggregationState LongDecimalWithOverflowAndLongState
state.addOverflow(overflow + otherState.getOverflow());
}
else {
state.setNotNull();
decimal[offset] = otherDecimal[otherOffset];
decimal[offset + 1] = otherDecimal[otherOffset + 1];
state.setOverflow(otherState.getOverflow());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,26 @@
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateMetadata;

@AccumulatorStateMetadata(stateFactoryClass = LongDecimalWithOverflowAndLongStateFactory.class, stateSerializerClass = LongDecimalWithOverflowAndLongStateSerializer.class)
public interface LongDecimalWithOverflowAndLongState
extends LongDecimalWithOverflowState
extends AccumulatorState
{
long getLong();

void setLong(long value);

void addLong(long value);

long[] getDecimalArray();

int getDecimalArrayOffset();

long getOverflow();

void setOverflow(long overflow);

void addOverflow(long overflow);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateFactory;

import javax.annotation.Nullable;

import static io.airlift.slice.SizeOf.instanceSize;
import static io.airlift.slice.SizeOf.sizeOf;
import static java.lang.System.arraycopy;

public class LongDecimalWithOverflowAndLongStateFactory
implements AccumulatorStateFactory<LongDecimalWithOverflowAndLongState>
Expand All @@ -35,17 +39,26 @@ public LongDecimalWithOverflowAndLongState createGroupedState()
}

public static class GroupedLongDecimalWithOverflowAndLongState
extends LongDecimalWithOverflowStateFactory.GroupedLongDecimalWithOverflowState
extends AbstractGroupedAccumulatorState
implements LongDecimalWithOverflowAndLongState
{
private static final int INSTANCE_SIZE = instanceSize(GroupedLongDecimalWithOverflowAndLongState.class);
private final LongBigArray longs = new LongBigArray();
/**
* Stores 128-bit decimals as pairs of longs
*/
private final LongBigArray unscaledDecimals = new LongBigArray();
@Nullable
private LongBigArray overflows; // lazily initialized on the first overflow

@Override
public void ensureCapacity(long size)
{
longs.ensureCapacity(size);
super.ensureCapacity(size);
unscaledDecimals.ensureCapacity(size * 2);
if (overflows != null) {
overflows.ensureCapacity(size);
}
}

@Override
Expand All @@ -66,27 +79,80 @@ public void addLong(long value)
longs.add(getGroupId(), value);
}

@Override
public long[] getDecimalArray()
{
return unscaledDecimals.getSegment(getGroupId() * 2);
}

@Override
public int getDecimalArrayOffset()
{
return unscaledDecimals.getOffset(getGroupId() * 2);
}

@Override
public long getOverflow()
{
if (overflows == null) {
return 0;
}
return overflows.get(getGroupId());
}

@Override
public void setOverflow(long overflow)
{
// setOverflow(0) must overwrite any existing overflow value
if (overflow == 0 && overflows == null) {
return;
}
long groupId = getGroupId();
if (overflows == null) {
overflows = new LongBigArray();
overflows.ensureCapacity(longs.getCapacity());
}
overflows.set(groupId, overflow);
}

@Override
public void addOverflow(long overflow)
{
if (overflow != 0) {
long groupId = getGroupId();
if (overflows == null) {
overflows = new LongBigArray();
overflows.ensureCapacity(longs.getCapacity());
}
overflows.add(groupId, overflow);
}
}

@Override
public long getEstimatedSize()
{
return INSTANCE_SIZE + longs.sizeOf() + isNotNull.sizeOf() + unscaledDecimals.sizeOf() + (overflows == null ? 0 : overflows.sizeOf());
return INSTANCE_SIZE + longs.sizeOf() + unscaledDecimals.sizeOf() + (overflows == null ? 0 : overflows.sizeOf());
}
}

public static class SingleLongDecimalWithOverflowAndLongState
extends LongDecimalWithOverflowStateFactory.SingleLongDecimalWithOverflowState
implements LongDecimalWithOverflowAndLongState
{
private static final int INSTANCE_SIZE = instanceSize(SingleLongDecimalWithOverflowAndLongState.class);
private static final int SIZE = (int) sizeOf(new long[2]);

protected long longValue;
private final long[] unscaledDecimal = new long[2];
private long longValue;
private long overflow;

public SingleLongDecimalWithOverflowAndLongState() {}

// for copying
private SingleLongDecimalWithOverflowAndLongState(long longValue)
private SingleLongDecimalWithOverflowAndLongState(long[] unscaledDecimal, long longValue, long overflow)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand. Was this a bug? It wasn't delegating to super constructor.

Does this deserve a separate commit/PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looked like a mistake in the current code to me but I don't know enough about how it is used to definitively say whether it would've resulted in a bug.
@kasiafi could you please take a look and clarify.

{
arraycopy(unscaledDecimal, 0, this.unscaledDecimal, 0, 2);
this.longValue = longValue;
this.overflow = overflow;
}

@Override
Expand All @@ -107,6 +173,36 @@ public void addLong(long value)
longValue += value;
}

@Override
public long[] getDecimalArray()
{
return unscaledDecimal;
}

@Override
public int getDecimalArrayOffset()
{
return 0;
}

@Override
public long getOverflow()
{
return overflow;
}

@Override
public void setOverflow(long overflow)
{
this.overflow = overflow;
}

@Override
public void addOverflow(long overflow)
{
this.overflow += overflow;
}

@Override
public long getEstimatedSize()
{
Expand All @@ -116,7 +212,7 @@ public long getEstimatedSize()
@Override
public AccumulatorState copy()
{
return new SingleLongDecimalWithOverflowAndLongState(longValue);
return new SingleLongDecimalWithOverflowAndLongState(unscaledDecimal, longValue, overflow);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ public Type getSerializedType()
@Override
public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder out)
{
if (state.isNotNull()) {
long count = state.getLong();
long count = state.getLong();
if (count > 0) {
long overflow = state.getOverflow();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
Expand Down Expand Up @@ -97,7 +97,6 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt
decimal[offset] = high;
state.setOverflow(overflow);
state.setLong(count);
state.setNotNull();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ public static class GroupedLongDecimalWithOverflowState
implements LongDecimalWithOverflowState
{
private static final int INSTANCE_SIZE = instanceSize(GroupedLongDecimalWithOverflowState.class);
protected final BooleanBigArray isNotNull = new BooleanBigArray();
private final BooleanBigArray isNotNull = new BooleanBigArray();
/**
* Stores 128-bit decimals as pairs of longs
*/
protected final LongBigArray unscaledDecimals = new LongBigArray();
private final LongBigArray unscaledDecimals = new LongBigArray();
@Nullable
protected LongBigArray overflows; // lazily initialized on the first overflow
private LongBigArray overflows; // lazily initialized on the first overflow

@Override
public void ensureCapacity(long size)
Expand Down Expand Up @@ -134,11 +134,11 @@ public static class SingleLongDecimalWithOverflowState
implements LongDecimalWithOverflowState
{
private static final int INSTANCE_SIZE = instanceSize(SingleLongDecimalWithOverflowState.class);
protected static final int SIZE = (int) sizeOf(new long[2]);
private static final int SIZE = (int) sizeOf(new long[2]);

protected final long[] unscaledDecimal = new long[2];
protected boolean isNotNull;
protected long overflow;
private final long[] unscaledDecimal = new long[2];
private boolean isNotNull;
private long overflow;

public SingleLongDecimalWithOverflowState() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowState;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
Expand Down Expand Up @@ -238,7 +237,7 @@ private static void addToState(DecimalType type, LongDecimalWithOverflowAndLongS
}
}

private Int128 getDecimal(LongDecimalWithOverflowState state)
private Int128 getDecimal(LongDecimalWithOverflowAndLongState state)
{
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

public class TestLongDecimalWithOverflowAndLongStateSerializer
{
Expand All @@ -35,11 +33,9 @@ public void testSerde(long low, long high, long overflow, long count, int expect
state.getDecimalArray()[1] = low;
state.setOverflow(overflow);
state.setLong(count);
state.setNotNull();

LongDecimalWithOverflowAndLongState outState = roundTrip(state, expectedLength);

assertTrue(outState.isNotNull());
assertEquals(outState.getDecimalArray()[0], high);
assertEquals(outState.getDecimalArray()[1], low);
assertEquals(outState.getOverflow(), overflow);
Expand All @@ -54,7 +50,7 @@ public void testNullSerde()

LongDecimalWithOverflowAndLongState outState = roundTrip(state, 0);

assertFalse(outState.isNotNull());
assertEquals(outState.getLong(), 0);
}

private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ public void ensureCapacity(long length)
grow(length);
}

public long getCapacity()
{
return capacity;
}

/**
* Copies this array, beginning at the specified sourceIndex, to the specified destinationIndex of
* the destination array. A subsequence of this array's components are copied to the destination
Expand Down