diff --git a/presto-spi/src/main/java/io/prestosql/spi/type/UnscaledDecimal128Arithmetic.java b/presto-spi/src/main/java/io/prestosql/spi/type/UnscaledDecimal128Arithmetic.java index bb7e89d14ee3..841a03247fb8 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/type/UnscaledDecimal128Arithmetic.java +++ b/presto-spi/src/main/java/io/prestosql/spi/type/UnscaledDecimal128Arithmetic.java @@ -25,7 +25,6 @@ import static io.prestosql.spi.type.Decimals.MAX_PRECISION; import static io.prestosql.spi.type.Decimals.longTenToNth; import static java.lang.Integer.toUnsignedLong; -import static java.lang.String.format; import static java.lang.System.arraycopy; import static java.util.Arrays.fill; @@ -353,37 +352,20 @@ else if (compare < 0) { */ private static long addUnsignedReturnOverflow(Slice left, Slice right, Slice result, boolean resultNegative) { - // TODO: consider two 7 bytes operations - int l0 = getInt(left, 0); - int l1 = getInt(left, 1); - int l2 = getInt(left, 2); - int l3 = getInt(left, 3); + long l0 = getLong(left, 0); + long l1 = getLong(left, 1); - int r0 = getInt(right, 0); - int r1 = getInt(right, 1); - int r2 = getInt(right, 2); - int r3 = getInt(right, 3); + long r0 = getLong(right, 0); + long r1 = getLong(right, 1); - long intermediateResult; - intermediateResult = toUnsignedLong(l0) + toUnsignedLong(r0); + long z0 = l0 + r0; + int overflow = unsignedIsSmaller(z0, l0) ? 1 : 0; - int z0 = (int) intermediateResult; + long intermediateResult = l1 + r1 + overflow; + long z1 = intermediateResult & (~SIGN_LONG_MASK); + pack(result, z0, z1, resultNegative); - intermediateResult = toUnsignedLong(l1) + toUnsignedLong(r1) + (intermediateResult >>> 32); - - int z1 = (int) intermediateResult; - - intermediateResult = toUnsignedLong(l2) + toUnsignedLong(r2) + (intermediateResult >>> 32); - - int z2 = (int) intermediateResult; - - intermediateResult = toUnsignedLong(l3) + toUnsignedLong(r3) + (intermediateResult >>> 32); - - int z3 = (int) intermediateResult & (~SIGN_INT_MASK); - - pack(result, z0, z1, z2, z3, resultNegative); - - return intermediateResult >> 31; + return intermediateResult >>> 63; } /** @@ -391,39 +373,17 @@ private static long addUnsignedReturnOverflow(Slice left, Slice right, Slice res */ private static void subtractUnsigned(Slice left, Slice right, Slice result, boolean resultNegative) { - // TODO: consider two 7 bytes operations - int l0 = getInt(left, 0); - int l1 = getInt(left, 1); - int l2 = getInt(left, 2); - int l3 = getInt(left, 3); - - int r0 = getInt(right, 0); - int r1 = getInt(right, 1); - int r2 = getInt(right, 2); - int r3 = getInt(right, 3); - - long intermediateResult; - intermediateResult = toUnsignedLong(l0) - toUnsignedLong(r0); - - int z0 = (int) intermediateResult; - - intermediateResult = toUnsignedLong(l1) - toUnsignedLong(r1) + (intermediateResult >> 32); + long l0 = getLong(left, 0); + long l1 = getLong(left, 1); - int z1 = (int) intermediateResult; + long r0 = getLong(right, 0); + long r1 = getLong(right, 1); - intermediateResult = toUnsignedLong(l2) - toUnsignedLong(r2) + (intermediateResult >> 32); + long z0 = l0 - r0; + int underflow = unsignedIsSmaller(l0, z0) ? 1 : 0; + long z1 = l1 - r1 - underflow; - int z2 = (int) intermediateResult; - - intermediateResult = toUnsignedLong(l3) - toUnsignedLong(r3) + (intermediateResult >> 32); - - int z3 = (int) intermediateResult; - - pack(result, z0, z1, z2, z3, resultNegative); - - if ((intermediateResult >> 32) != 0) { - throw new IllegalStateException(format("Non empty carry over after subtracting [%d]. right > left?", (intermediateResult >> 32))); - } + pack(result, z0, z1, resultNegative); } public static Slice multiply(Slice left, Slice right) @@ -1737,6 +1697,14 @@ private static void setRawLong(Slice decimal, int index, long value) decimal.setLong(SIZE_OF_LONG * index, value); } + /** + * Based on Long.compareUnsigned() + */ + private static boolean unsignedIsSmaller(long first, long second) + { + return first + Long.MIN_VALUE < second + Long.MIN_VALUE; + } + private static void checkArgument(boolean condition) { if (!condition) {