diff --git a/CHANGELOG.md b/CHANGELOG.md index 96da44e83e1..7666c304169 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,6 +76,7 @@ - PR #6727 Remove 2nd type-dispatcher call from cudf::reduce - PR #6749 Update nested JNI builder so we can do it incrementally - PR #6748 Add Java API to concatenate serialized tables to ContiguousTable +- PR #6734 Binary operations support for decimal type in cudf Java ## Bug Fixes diff --git a/java/src/main/java/ai/rapids/cudf/BinaryOperable.java b/java/src/main/java/ai/rapids/cudf/BinaryOperable.java index c0bd0565a37..d0c45d1af2a 100644 --- a/java/src/main/java/ai/rapids/cudf/BinaryOperable.java +++ b/java/src/main/java/ai/rapids/cudf/BinaryOperable.java @@ -34,6 +34,9 @@ public interface BinaryOperable { *
* BOOL8 is treated like an INT8. Math on boolean operations makes little sense. If * you want to stay as a BOOL8 you will need to explicitly specify the output type. + * For decimal types, DECIMAL32 and DECIMAL64 takes in another parameter `scale`. DType is created + * with scale=0 as scale is required. Dtype is discarded for binary operations for decimal + * types in cudf as a new DType is created for output type with the new scale. */ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) { DType a = lhs.getType(); @@ -76,6 +79,25 @@ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) { if (a == DType.BOOL8 || b == DType.BOOL8) { return DType.BOOL8; } + if (a.isDecimalType() && b.isDecimalType()) { + // Here scale is created with value 0 as `scale` is required to create DType of + // decimal type. Dtype is discarded for binary operations for decimal types in cudf as a new + // DType is created for output type with new scale. New scale for output depends upon operator. + int scale = 0; + if (a.typeId == DType.DTypeEnum.DECIMAL32) { + if (b.typeId == DType.DTypeEnum.DECIMAL32) { + return DType.create(DType.DTypeEnum.DECIMAL32, scale); + } else { + throw new IllegalArgumentException("Both columns must be of the same fixed_point type"); + } + } else if (a.typeId == DType.DTypeEnum.DECIMAL64) { + if (b.typeId == DType.DTypeEnum.DECIMAL64) { + return DType.create(DType.DTypeEnum.DECIMAL64, scale); + } else { + throw new IllegalArgumentException("Both columns must be of the same fixed_point type"); + } + } + } throw new IllegalArgumentException("Unsupported types " + a + " and " + b); } @@ -94,7 +116,9 @@ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) { ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType); /** - * Add + operator. this + rhs + * Add one vector to another with the given output type. this + rhs + * Output type is ignored for the operations between decimal types and + * it is always decimal type. */ default ColumnVector add(BinaryOperable rhs, DType outType) { return binaryOp(BinaryOp.ADD, rhs, outType); @@ -109,6 +133,8 @@ default ColumnVector add(BinaryOperable rhs) { /** * Subtract one vector from another with the given output type. this - rhs + * Output type is ignored for the operations between decimal types and + * it is always decimal type. */ default ColumnVector sub(BinaryOperable rhs, DType outType) { return binaryOp(BinaryOp.SUB, rhs, outType); @@ -123,6 +149,8 @@ default ColumnVector sub(BinaryOperable rhs) { /** * Multiply two vectors together with the given output type. this * rhs + * Output type is ignored for the operations between decimal types and + * it is always decimal type. */ default ColumnVector mul(BinaryOperable rhs, DType outType) { return binaryOp(BinaryOp.MUL, rhs, outType); @@ -137,6 +165,8 @@ default ColumnVector mul(BinaryOperable rhs) { /** * Divide one vector by another with the given output type. this / rhs + * Output type is ignored for the operations between decimal types and + * it is always decimal type. */ default ColumnVector div(BinaryOperable rhs, DType outType) { return binaryOp(BinaryOp.DIV, rhs, outType); diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 7e1341644bc..5616e1e2744 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -544,8 +544,8 @@ public static HostColumnVector fromDecimals(BigDecimal... values) { BigDecimal maxDec = Arrays.stream(values).filter(Objects::nonNull) .max(Comparator.comparingInt(BigDecimal::precision)) .orElse(BigDecimal.ZERO); - int maxScale = Arrays.stream(values) - .map(decimal -> (decimal == null) ? 0 : decimal.scale()) + int maxScale = Arrays.stream(values).filter(Objects::nonNull) + .map(decimal -> decimal.scale()) .max(Comparator.naturalOrder()) .orElse(0); maxDec = maxDec.setScale(maxScale, RoundingMode.UNNECESSARY); @@ -1364,8 +1364,10 @@ public final Builder append(BigDecimal value, RoundingMode roundingMode) { assert currentIndex < rows; BigInteger unscaledValue = value.setScale(-type.getScale(), roundingMode).unscaledValue(); if (type.typeId == DType.DTypeEnum.DECIMAL32) { + assert value.precision() <= DType.DECIMAL32_MAX_PRECISION : "value exceeds maximum precision for DECIMAL32"; data.setInt(currentIndex * type.getSizeInBytes(), unscaledValue.intValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL64) { + assert value.precision() <= DType.DECIMAL64_MAX_PRECISION : "value exceeds maximum precision for DECIMAL64 "; data.setLong(currentIndex * type.getSizeInBytes(), unscaledValue.longValueExact()); } else { throw new IllegalStateException(type + " is not a supported decimal type."); diff --git a/java/src/test/java/ai/rapids/cudf/BinaryOpTest.java b/java/src/test/java/ai/rapids/cudf/BinaryOpTest.java index 1342abca041..df4afb5ff60 100644 --- a/java/src/test/java/ai/rapids/cudf/BinaryOpTest.java +++ b/java/src/test/java/ai/rapids/cudf/BinaryOpTest.java @@ -21,13 +21,21 @@ import ai.rapids.cudf.HostColumnVector.Builder; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.Arrays; import java.util.stream.IntStream; import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; import static ai.rapids.cudf.TestUtils.*; +import static org.junit.jupiter.api.Assertions.assertThrows; public class BinaryOpTest extends CudfTestBase { + private static final int dec32Scale_1 = 2; + private static final int dec32Scale_2 = -3; + private static final int dec64Scale_1 = 6; + private static final int dec64Scale_2 = -2; + private static final Integer[] INTS_1 = new Integer[]{1, 2, 3, 4, 5, null, 100}; private static final Integer[] INTS_2 = new Integer[]{10, 20, 30, 40, 50, 60, 100}; private static final Integer[] UINTS_1 = new Integer[]{10, -20, 30, -40, 50, -60, 100}; @@ -43,6 +51,26 @@ public class BinaryOpTest extends CudfTestBase { private static final Boolean[] BOOLEANS_1 = new Boolean[]{true, true, false, false, null}; private static final Boolean[] BOOLEANS_2 = new Boolean[]{true, false, true, false, true}; private static final int[] SHIFT_BY = new int[]{1, 2, 3, 4, 5, 10, 20}; + private static final int[] DECIMAL32_1 = new int[]{1000, 2000, 3000, 4000, 5000}; + private static final int[] DECIMAL32_2 = new int[]{100, 200, 300, 400, 50}; + private static final long[] DECIMAL64_1 = new long[]{10L, 23L, 12L, 24L, 123456789L}; + private static final long[] DECIMAL64_2 = new long[]{20L, 13L, 22L, 14L, 132457689L}; + + private static final BigDecimal[] BIGDECIMAL32_1 = new BigDecimal[]{ + BigDecimal.valueOf(12, dec32Scale_1), + BigDecimal.valueOf(11, dec32Scale_1), + BigDecimal.valueOf(20, dec32Scale_1), + null, + BigDecimal.valueOf(25, dec32Scale_1) + }; + + private static final BigDecimal[] BIGDECIMAL32_2 = new BigDecimal[]{ + BigDecimal.valueOf(12, dec32Scale_2), + BigDecimal.valueOf(2, dec32Scale_2), + null, + BigDecimal.valueOf(16, dec32Scale_2), + BigDecimal.valueOf(10, dec32Scale_2) + }; interface CpuOpVV { void computeNullSafe(Builder ret, HostColumnVector lhs, HostColumnVector rhs, int index); @@ -218,7 +246,11 @@ public void testAdd() { ColumnVector lcv2 = ColumnVector.fromBoxedLongs(LONGS_2); ColumnVector ulcv1 = ColumnVector.fromBoxedUnsignedLongs(LONGS_1); ColumnVector dcv1 = ColumnVector.fromBoxedDoubles(DOUBLES_1); - ColumnVector dcv2 = ColumnVector.fromBoxedDoubles(DOUBLES_2)) { + ColumnVector dcv2 = ColumnVector.fromBoxedDoubles(DOUBLES_2); + ColumnVector dec32cv1 = ColumnVector.fromDecimals(BIGDECIMAL32_1); + ColumnVector dec32cv2 = ColumnVector.fromDecimals(BIGDECIMAL32_2); + ColumnVector dec64cv1 = ColumnVector.decimalFromLongs(-dec64Scale_1, DECIMAL64_1); + ColumnVector dec64cv2 = ColumnVector.decimalFromLongs(-dec64Scale_2, DECIMAL64_2)) { try (ColumnVector add = icv1.add(icv2); ColumnVector expected = forEach(DType.INT32, icv1, icv2, (b, l, r, i) -> b.append(l.getInt(i) + r.getInt(i)))) { @@ -283,6 +315,31 @@ public void testAdd() { assertColumnsAreEqual(addIntFirst, addDoubleFirst, "int + double vs double + int"); } + try (ColumnVector add = dec32cv1.add(dec32cv2)) { + try (ColumnVector expected = forEach( + DType.create(DType.DTypeEnum.DECIMAL32, -2), dec32cv1, dec32cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).add(r.getBigDecimal(i))))) { + assertColumnsAreEqual(expected, add, "dec32"); + } + } + + try (ColumnVector add = dec64cv1.add(dec64cv2)) { + try (ColumnVector expected = forEach( + DType.create(DType.DTypeEnum.DECIMAL64, -6), dec64cv1, dec64cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).add(r.getBigDecimal(i))))) { + assertColumnsAreEqual(expected, add, "dec64"); + } + } + + try (Scalar s = Scalar.fromDecimal(2, 100); + ColumnVector add = dec32cv1.add(s)) { + try (ColumnVector expected = forEachS( + DType.create(DType.DTypeEnum.DECIMAL32, -2), dec32cv1, BigDecimal.valueOf(100, -2), + (b, l, r, i) -> b.append(l.getBigDecimal(i).add(r)))) { + assertColumnsAreEqual(expected, add, "dec32 + scalar"); + } + } + try (Scalar s = Scalar.fromFloat(1.1f); ColumnVector add = lcv1.add(s); ColumnVector expected = forEachS(DType.FLOAT32, lcv1, 1.1f, @@ -320,7 +377,11 @@ public void testSub() { ColumnVector lcv2 = ColumnVector.fromBoxedLongs(LONGS_2); ColumnVector ulcv1 = ColumnVector.fromBoxedUnsignedLongs(LONGS_1); ColumnVector dcv1 = ColumnVector.fromBoxedDoubles(DOUBLES_1); - ColumnVector dcv2 = ColumnVector.fromBoxedDoubles(DOUBLES_2)) { + ColumnVector dcv2 = ColumnVector.fromBoxedDoubles(DOUBLES_2); + ColumnVector dec32cv1 = ColumnVector.fromDecimals(BIGDECIMAL32_1); + ColumnVector dec32cv2 = ColumnVector.fromDecimals(BIGDECIMAL32_2); + ColumnVector dec64cv1 = ColumnVector.decimalFromLongs(-dec64Scale_1, DECIMAL64_1); + ColumnVector dec64cv2 = ColumnVector.decimalFromLongs(-dec64Scale_2, DECIMAL64_2)) { try (ColumnVector sub = icv1.sub(icv2); ColumnVector expected = forEach(DType.INT32, icv1, icv2, (b, l, r, i) -> b.append(l.getInt(i) - r.getInt(i)))) { @@ -387,6 +448,31 @@ public void testSub() { assertColumnsAreEqual(expected, sub, "double - int"); } + try (ColumnVector sub = dec32cv1.sub(dec32cv2)) { + try (ColumnVector expected = forEach( + DType.create(DType.DTypeEnum.DECIMAL32, -2), dec32cv1, dec32cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).subtract(r.getBigDecimal(i))))) { + assertColumnsAreEqual(expected, sub, "dec32"); + } + } + + try (ColumnVector sub = dec64cv1.sub(dec64cv2)) { + try (ColumnVector expected = forEach( + DType.create(DType.DTypeEnum.DECIMAL64, -6), dec64cv1, dec64cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).subtract(r.getBigDecimal(i))))) { + assertColumnsAreEqual(expected, sub, "dec64"); + } + } + + try (Scalar s = Scalar.fromDecimal(2, 100); + ColumnVector sub = dec32cv1.sub(s)) { + try (ColumnVector expected = forEachS( + DType.create(DType.DTypeEnum.DECIMAL32, -2), dec32cv1, BigDecimal.valueOf(100, -2), + (b, l, r, i) -> b.append(l.getBigDecimal(i).subtract(r)))) { + assertColumnsAreEqual(expected, sub, "dec32 - scalar"); + } + } + try (Scalar s = Scalar.fromFloat(1.1f); ColumnVector sub = lcv1.sub(s); ColumnVector expected = forEachS(DType.FLOAT32, lcv1, 1.1f, @@ -417,13 +503,42 @@ public void testSub() { @Test public void testMul() { try (ColumnVector icv = ColumnVector.fromBoxedInts(INTS_1); - ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1)) { + ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1); + ColumnVector dec32cv1 = ColumnVector.fromDecimals(BIGDECIMAL32_1); + ColumnVector dec32cv2 = ColumnVector.fromDecimals(BIGDECIMAL32_2); + ColumnVector dec64cv1 = ColumnVector.decimalFromLongs(-dec64Scale_1, DECIMAL64_1); + ColumnVector dec64cv2 = ColumnVector.decimalFromLongs(-dec64Scale_2, DECIMAL64_2)) { try (ColumnVector answer = icv.mul(dcv); ColumnVector expected = forEach(DType.FLOAT64, icv, dcv, (b, l, r, i) -> b.append(l.getInt(i) * r.getDouble(i)))) { assertColumnsAreEqual(expected, answer, "int32 * double"); } + try (ColumnVector mul = dec32cv1.mul(dec32cv2)) { + try (ColumnVector expected = forEach( + DType.create(DType.DTypeEnum.DECIMAL32, 1), dec32cv1, dec32cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).multiply(r.getBigDecimal(i))))) { + assertColumnsAreEqual(expected, mul, "dec32"); + } + } + + try (ColumnVector mul = dec64cv1.mul(dec64cv2)) { + try (ColumnVector expected = forEach( + DType.create(DType.DTypeEnum.DECIMAL64, -4), dec64cv1, dec64cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).multiply(r.getBigDecimal(i))))) { + assertColumnsAreEqual(expected, mul, "dec64"); + } + } + + try (Scalar s = Scalar.fromDecimal(2, 100); + ColumnVector mul = dec32cv1.mul(s)) { + try (ColumnVector expected = forEachS( + DType.create(DType.DTypeEnum.DECIMAL32, 0), dec32cv1, BigDecimal.valueOf(100, -2), + (b, l, r, i) -> b.append(l.getBigDecimal(i).multiply(r)))) { + assertColumnsAreEqual(expected, mul, "dec32 * scalar"); + } + } + try (Scalar s = Scalar.fromFloat(1.1f); ColumnVector answer = icv.mul(s); ColumnVector expected = forEachS(DType.FLOAT32, icv, 1.1f, @@ -451,13 +566,44 @@ public void testMul() { @Test public void testDiv() { try (ColumnVector icv = ColumnVector.fromBoxedInts(INTS_1); - ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1)) { + ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1); + ColumnVector dec32cv1 = ColumnVector.fromDecimals(BIGDECIMAL32_1); + ColumnVector dec32cv2 = ColumnVector.fromDecimals(BIGDECIMAL32_2); + ColumnVector dec64cv1 = ColumnVector.decimalFromLongs(-dec64Scale_1, DECIMAL64_1); + ColumnVector dec64cv2 = ColumnVector.decimalFromLongs(-dec64Scale_2, DECIMAL64_2)) { try (ColumnVector answer = icv.div(dcv); ColumnVector expected = forEach(DType.FLOAT64, icv, dcv, (b, l, r, i) -> b.append(l.getInt(i) / r.getDouble(i)))) { assertColumnsAreEqual(expected, answer, "int32 / double"); } + try (ColumnVector div = dec32cv1.div(dec32cv2)) { + try (ColumnVector expected = forEach( + DType.create(DType.DTypeEnum.DECIMAL32, -5), dec32cv1, dec32cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).divide( + r.getBigDecimal(i), 5, RoundingMode.DOWN), RoundingMode.DOWN))) { + assertColumnsAreEqual(expected, div, "dec32"); + } + } + + try (ColumnVector div = dec64cv1.div(dec64cv2)) { + try (ColumnVector expected = forEach( + DType.create(DType.DTypeEnum.DECIMAL64, -8), dec64cv1, dec64cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).divide( + r.getBigDecimal(i), 8, RoundingMode.DOWN), RoundingMode.DOWN))) { + assertColumnsAreEqual(expected, div, "dec64"); + } + } + + try (Scalar s = Scalar.fromDecimal(2, 100); + ColumnVector div = s.div(dec32cv1)) { + try (ColumnVector expected = forEachS( + DType.create(DType.DTypeEnum.DECIMAL32, 4), BigDecimal.valueOf(100, -2), dec32cv1, + (b, l, r, i) -> b.append(l.divide(r.getBigDecimal(i), -4, RoundingMode.DOWN)))) { + assertColumnsAreEqual(expected, div, "scalar dec32 / dec32"); + } + } + try (Scalar s = Scalar.fromFloat(1.1f); ColumnVector answer = icv.div(s); ColumnVector expected = forEachS(DType.FLOAT32, icv, 1.1f, @@ -597,13 +743,29 @@ public void testPow() { @Test public void testEqual() { try (ColumnVector icv = ColumnVector.fromBoxedInts(INTS_1); - ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1)) { + ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1); + ColumnVector dec32cv_1 = ColumnVector.decimalFromInts(-dec32Scale_1, DECIMAL32_1); + ColumnVector dec32cv_2 = ColumnVector.decimalFromInts(-dec32Scale_2, DECIMAL32_2)) { try (ColumnVector answer = icv.equalTo(dcv); ColumnVector expected = forEach(DType.BOOL8, icv, dcv, (b, l, r, i) -> b.append(l.getInt(i) == r.getDouble(i)))) { assertColumnsAreEqual(expected, answer, "int32 == double"); } + try (ColumnVector answer = dec32cv_1.equalTo(dec32cv_2); + ColumnVector expected = forEach(DType.BOOL8, dec32cv_1, dec32cv_2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).compareTo(r.getBigDecimal(i)) == 0))) { + assertColumnsAreEqual(expected, answer, "dec32 == dec32 "); + } + + try (Scalar s = Scalar.fromDecimal(-2, 200); + ColumnVector answer = dec32cv_2.equalTo(s)) { + try (ColumnVector expected = forEachS(DType.BOOL8, dec32cv_1, BigDecimal.valueOf(200, 2), + (b, l, r, i) -> b.append(l.getBigDecimal(i).compareTo(r) == 0))) { + assertColumnsAreEqual(expected, answer, "dec32 == scalar dec32"); + } + } + try (Scalar s = Scalar.fromFloat(1.0f); ColumnVector answer = icv.equalTo(s); ColumnVector expected = forEachS(DType.BOOL8, icv, 1.0f, @@ -675,13 +837,29 @@ public void testStringEqualScalarNotPresent() { @Test public void testNotEqual() { try (ColumnVector icv = ColumnVector.fromBoxedInts(INTS_1); - ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1)) { + ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1); + ColumnVector dec32cv_1 = ColumnVector.decimalFromInts(-dec32Scale_1, DECIMAL32_1); + ColumnVector dec32cv_2 = ColumnVector.decimalFromInts(-dec32Scale_2, DECIMAL32_2)) { try (ColumnVector answer = icv.notEqualTo(dcv); ColumnVector expected = forEach(DType.BOOL8, icv, dcv, (b, l, r, i) -> b.append(l.getInt(i) != r.getDouble(i)))) { assertColumnsAreEqual(expected, answer, "int32 != double"); } + try (ColumnVector answer = dec32cv_1.notEqualTo(dec32cv_2); + ColumnVector expected = forEach(DType.BOOL8, dec32cv_1, dec32cv_2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).compareTo(r.getBigDecimal(i)) != 0))) { + assertColumnsAreEqual(expected, answer, "dec32 != dec32 "); + } + + try (Scalar s = Scalar.fromDecimal(-2, 200); + ColumnVector answer = dec32cv_2.notEqualTo(s)) { + try (ColumnVector expected = forEachS(DType.BOOL8, dec32cv_1, BigDecimal.valueOf(200, 2), + (b, l, r, i) -> b.append(l.getBigDecimal(i).compareTo(r) != 0))) { + assertColumnsAreEqual(expected, answer, "dec32 != scalar dec32"); + } + } + try (Scalar s = Scalar.fromFloat(1.0f); ColumnVector answer = icv.notEqualTo(s); ColumnVector expected = forEachS(DType.BOOL8, icv, 1.0f, @@ -743,13 +921,21 @@ public void testStringNotEqualScalarNotPresent() { @Test public void testLessThan() { try (ColumnVector icv = ColumnVector.fromBoxedInts(INTS_1); - ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1)) { + ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1); + ColumnVector dec32cv_1 = ColumnVector.decimalFromInts(-dec32Scale_1, DECIMAL32_1); + ColumnVector dec32cv_2 = ColumnVector.decimalFromInts(-dec32Scale_2, DECIMAL32_2)) { try (ColumnVector answer = icv.lessThan(dcv); ColumnVector expected = forEach(DType.BOOL8, icv, dcv, (b, l, r, i) -> b.append(l.getInt(i) < r.getDouble(i)))) { assertColumnsAreEqual(expected, answer, "int32 < double"); } + try (ColumnVector answer = dec32cv_1.lessThan(dec32cv_2); + ColumnVector expected = forEach(DType.BOOL8, dec32cv_1, dec32cv_2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).compareTo(r.getBigDecimal(i)) < 0))) { + assertColumnsAreEqual(expected, answer, "dec32 < dec32 "); + } + try (Scalar s = Scalar.fromFloat(1.0f); ColumnVector answer = icv.lessThan(s); ColumnVector expected = forEachS(DType.BOOL8, icv, 1.0f, @@ -818,13 +1004,21 @@ public void testStringLessThanScalarNotPresent() { @Test public void testGreaterThan() { try (ColumnVector icv = ColumnVector.fromBoxedInts(INTS_1); - ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1)) { + ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1); + ColumnVector dec32cv1 = ColumnVector.fromDecimals(BIGDECIMAL32_1); + ColumnVector dec32cv2 = ColumnVector.fromDecimals(BIGDECIMAL32_2)) { try (ColumnVector answer = icv.greaterThan(dcv); ColumnVector expected = forEach(DType.BOOL8, icv, dcv, (b, l, r, i) -> b.append(l.getInt(i) > r.getDouble(i)))) { assertColumnsAreEqual(expected, answer, "int32 > double"); } + try (ColumnVector answer = dec32cv2.greaterThan(dec32cv1); + ColumnVector expected = forEach(DType.BOOL8, dec32cv2, dec32cv1, + (b, l, r, i) -> b.append(l.getBigDecimal(i).compareTo(r.getBigDecimal(i)) > 0))) { + assertColumnsAreEqual(expected, answer, "dec32 > dec32 "); + } + try (Scalar s = Scalar.fromFloat(1.0f); ColumnVector answer = icv.greaterThan(s); ColumnVector expected = forEachS(DType.BOOL8, icv, 1.0f, @@ -892,7 +1086,8 @@ public void testStringGreaterThanScalarNotPresent() { @Test public void testLessOrEqualTo() { try (ColumnVector icv = ColumnVector.fromBoxedInts(INTS_1); - ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1)) { + ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1); + ColumnVector dec32cv = ColumnVector.decimalFromInts(-dec32Scale_2, DECIMAL32_2)) { try (ColumnVector answer = icv.lessOrEqualTo(dcv); ColumnVector expected = forEach(DType.BOOL8, icv, dcv, (b, l, r, i) -> b.append(l.getInt(i) <= r.getDouble(i)))) { @@ -912,6 +1107,14 @@ public void testLessOrEqualTo() { (b, l, r, i) -> b.append(l <= r.getInt(i)))) { assertColumnsAreEqual(expected, answer, "scalar short <= int32"); } + + try (Scalar s = Scalar.fromDecimal(-2, 200); + ColumnVector answer = dec32cv.lessOrEqualTo(s)) { + try (ColumnVector expected = forEachS(DType.BOOL8, dec32cv, BigDecimal.valueOf(200, 2), + (b, l, r, i) -> b.append(l.getBigDecimal(i).compareTo(r) <= 0))) { + assertColumnsAreEqual(expected, answer, "dec32 <= scalar dec32"); + } + } } } @@ -966,7 +1169,8 @@ public void testStringLessOrEqualToScalarNotPresent() { @Test public void testGreaterOrEqualTo() { try (ColumnVector icv = ColumnVector.fromBoxedInts(INTS_1); - ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1)) { + ColumnVector dcv = ColumnVector.fromBoxedDoubles(DOUBLES_1); + ColumnVector dec32cv = ColumnVector.decimalFromInts(-dec32Scale_2, DECIMAL32_2)) { try (ColumnVector answer = icv.greaterOrEqualTo(dcv); ColumnVector expected = forEach(DType.BOOL8, icv, dcv, (b, l, r, i) -> b.append(l.getInt(i) >= r.getDouble(i)))) { @@ -986,6 +1190,14 @@ public void testGreaterOrEqualTo() { (b, l, r, i) -> b.append(l >= r.getInt(i)))) { assertColumnsAreEqual(expected, answer, "scalar short >= int32"); } + + try (Scalar s = Scalar.fromDecimal(-2, 200); + ColumnVector answer = dec32cv.greaterOrEqualTo(s)) { + try (ColumnVector expected = forEachS(DType.BOOL8, dec32cv, BigDecimal.valueOf(200, 2), + (b, l, r, i) -> b.append(l.getBigDecimal(i).compareTo(r) >= 0))) { + assertColumnsAreEqual(expected, answer, "dec32 >= scalar dec32"); + } + } } } @@ -1416,4 +1628,17 @@ public void testMinNullAware() { } } + @Test + public void testDecimalTypeThrowsException() { + try (ColumnVector dec64cv1 = ColumnVector.decimalFromLongs(-dec64Scale_1+10, DECIMAL64_1); + ColumnVector dec64cv2 = ColumnVector.decimalFromLongs(-dec64Scale_2- 10 , DECIMAL64_2)) { + assertThrows(ArithmeticException.class, + () -> { + try (ColumnVector expected = forEach + (DType.create(DType.DTypeEnum.DECIMAL64, -6), dec64cv1, dec64cv2, + (b, l, r, i) -> b.append(l.getBigDecimal(i).add(r.getBigDecimal(i))))) { + } + }); + } + } } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index c7bc7989be9..97ceb23c837 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -195,6 +195,7 @@ public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, l case UINT32: // fall through case TIMESTAMP_DAYS: case DURATION_DAYS: + case DECIMAL32: assertEquals(expected.getInt(expectedRow), cv.getInt(tableRow), "Column " + colName + " Row " + tableRow); break; @@ -208,6 +209,7 @@ public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, l case TIMESTAMP_MILLISECONDS: // fall through case TIMESTAMP_NANOSECONDS: // fall through case TIMESTAMP_SECONDS: + case DECIMAL64: assertEquals(expected.getLong(expectedRow), cv.getLong(tableRow), "Column " + colName + " Row " + tableRow); break;