diff --git a/CHANGELOG.md b/CHANGELOG.md index e895a657de8..04ffc5868e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ - PR #6939 Use simplified `rmm::exec_policy` - PR #6512 Refactor rolling.cu to reduce compile time - PR #6982 Disable some pragma unroll statements in thrust `sort.h` +- PR #7051 Verify decimal cast in java package ## Bug Fixes diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 54646583831..738bacfe130 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2215,6 +2215,137 @@ private static void testCastFixedWidthToStringsAndBack(DType type, Supplier ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, Integer.MIN_VALUE, Integer.MAX_VALUE), + () -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, Integer.MIN_VALUE, Integer.MAX_VALUE), + new Long[]{1L, -21L, 345L, null, 8008L, (long) Integer.MIN_VALUE, (long) Integer.MAX_VALUE} + ); + testCastNumericToDecimalsAndBack(DType.INT32, false, -2, + () -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, 0, 123456), + () -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, 0, 123456), + new Long[]{100L, -2100L, 34500L, null, 800800L, 0L, 12345600L} + ); + testCastNumericToDecimalsAndBack(DType.INT32, false, 2, + () -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, 0, 123456), + () -> ColumnVector.fromBoxedInts(0, 0, 300, null, 8000, 0, 123400), + new Long[]{0L, 0L, 3L, null, 80L, 0L, 1234L} + ); + } + + @Test + void testCastLongToDecimal() { + testCastNumericToDecimalsAndBack(DType.INT64, false, 0, + () -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8008L, Long.MIN_VALUE, Long.MAX_VALUE), + () -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8008L, Long.MIN_VALUE, Long.MAX_VALUE), + new Long[]{1L, -21L, 345L, null, 8008L, Long.MIN_VALUE, Long.MAX_VALUE} + ); + testCastNumericToDecimalsAndBack(DType.INT64, false, -1, + () -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8008L, 0L, 123456L), + () -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8008L, 0L, 123456L), + new Long[]{10L, -210L, 3450L, null, 80080L, 0L, 1234560L} + ); + testCastNumericToDecimalsAndBack(DType.INT64, false, 1, + () -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8018L, 0L, 123456L), + () -> ColumnVector.fromBoxedLongs(0L, -20L, 340L, null, 8010L, 0L, 123450L), + new Long[]{0L, -2L, 34L, null, 801L, 0L, 12345L} + ); + } + + @Test + void testCastFloatToDecimal() { + testCastNumericToDecimalsAndBack(DType.FLOAT32, true, 0, + () -> ColumnVector.fromBoxedFloats(1.0f, 2.1f, -3.23f, null, 2.41281f, 1378952.001f), + () -> ColumnVector.fromBoxedFloats(1f, 2f, -3f, null, 2f, 1378952f), + new Long[]{1L, 2L, -3L, null, 2L, 1378952L} + ); + testCastNumericToDecimalsAndBack(DType.FLOAT32, true, -1, + () -> ColumnVector.fromBoxedFloats(1.0f, 2.1f, -3.23f, null, 2.41281f, 1378952.001f), + () -> ColumnVector.fromBoxedFloats(1f, 2.1f, -3.2f, null, 2.4f, 1378952f), + new Long[]{10L, 21L, -32L, null, 24L, 13789520L} + ); + testCastNumericToDecimalsAndBack(DType.FLOAT32, true, 1, + () -> ColumnVector.fromBoxedFloats(1.0f, 21.1f, -300.23f, null, 24128.1f, 1378952.001f), + () -> ColumnVector.fromBoxedFloats(0f, 20f, -300f, null, 24120f, 1378950f), + new Long[]{0L, 2L, -30L, null, 2412L, 137895L} + ); + } + + @Test + void testCastDoubleToDecimal() { + testCastNumericToDecimalsAndBack(DType.FLOAT64, false, 0, + () -> ColumnVector.fromBoxedDoubles(1.0, 2.1, -3.23, null, 2.41281, (double) Long.MAX_VALUE), + () -> ColumnVector.fromBoxedDoubles(1.0, 2.0, -3.0, null, 2.0, (double) Long.MAX_VALUE), + new Long[]{1L, 2L, -3L, null, 2L, Long.MAX_VALUE} + ); + testCastNumericToDecimalsAndBack(DType.FLOAT64, false, -2, + () -> ColumnVector.fromBoxedDoubles(1.0, 2.1, -3.23, null, 2.41281, -55.01999), + () -> ColumnVector.fromBoxedDoubles(1.0, 2.1, -3.23, null, 2.41, -55.01), + new Long[]{100L, 210L, -323L, null, 241L, -5501L} + ); + testCastNumericToDecimalsAndBack(DType.FLOAT64, false, 1, + () -> ColumnVector.fromBoxedDoubles(1.0, 23.1, -3089.23, null, 200.41281, -199.01999), + () -> ColumnVector.fromBoxedDoubles(0.0, 20.0, -3080.0, null, 200.0, -190.0), + new Long[]{0L, 2L, -308L, null, 20L, -19L} + ); + } + + @Test + void testCastDecimalToDecimal() { + // DECIMAL32(scale: 0) -> DECIMAL32(scale: 0) + testCastNumericToDecimalsAndBack(DType.create(DType.DTypeEnum.DECIMAL32, 0), true, -0, + () -> ColumnVector.decimalFromInts(0, 1, 12, -234, 5678, Integer.MIN_VALUE / 100), + () -> ColumnVector.decimalFromInts(0, 1, 12, -234, 5678, Integer.MIN_VALUE / 100), + new Long[]{1L, 12L, -234L, 5678L, (long) Integer.MIN_VALUE / 100} + ); + // DECIMAL32(scale: 0) -> DECIMAL64(scale: -2) + testCastNumericToDecimalsAndBack(DType.create(DType.DTypeEnum.DECIMAL32, 0), false, -2, + () -> ColumnVector.decimalFromInts(0, 1, 12, -234, 5678, Integer.MIN_VALUE / 100), + () -> ColumnVector.decimalFromInts(0, 1, 12, -234, 5678, Integer.MIN_VALUE / 100), + new Long[]{100L, 1200L, -23400L, 567800L, (long) Integer.MIN_VALUE / 100 * 100} + ); + // DECIMAL64(scale: -3) -> DECIMAL64(scale: -1) + DType dt = DType.create(DType.DTypeEnum.DECIMAL64, -3); + testCastNumericToDecimalsAndBack(dt, false, -1, + () -> ColumnVector.decimalFromDoubles(dt, RoundingMode.UNNECESSARY, -1000.1, 1.222, 0.03, -4.678, 16789431.0), + () -> ColumnVector.decimalFromDoubles(dt, RoundingMode.UNNECESSARY, -1000.1, 1.2, 0, -4.6, 16789431.0), + new Long[]{-10001L, 12L, 0L, -46L, 167894310L} + ); + // DECIMAL64(scale: -3) -> DECIMAL64(scale: 2) + DType dt2 = DType.create(DType.DTypeEnum.DECIMAL64, -3); + testCastNumericToDecimalsAndBack(dt2, false, 2, + () -> ColumnVector.decimalFromDoubles(dt2, RoundingMode.UNNECESSARY, -1013.1, 14.222, 780.03, -4.678, 16789431.0), + () -> ColumnVector.decimalFromDoubles(dt2, RoundingMode.UNNECESSARY, -1000, 0, 700, 0, 16789400), + new Long[]{-10L, 0L, 7L, 0L, 167894L} + ); + // DECIMAL64(scale: -3) -> DECIMAL32(scale: -3) + testCastNumericToDecimalsAndBack(dt2, true, -3, + () -> ColumnVector.decimalFromDoubles(dt2, RoundingMode.UNNECESSARY, -1013.1, 14.222, 780.03, -4.678, 16789.0), + () -> ColumnVector.decimalFromDoubles(dt2, RoundingMode.UNNECESSARY, -1013.1, 14.222, 780.03, -4.678, 16789.0), + new Long[]{-1013100L, 14222L, 780030L, -4678L, 16789000L} + ); + } + + private static void testCastNumericToDecimalsAndBack(DType sourceType, boolean isDec32, int scale, + Supplier sourceData, + Supplier returnData, + Long[] unscaledDecimal) { + DType decimalType = DType.create(isDec32 ? DType.DTypeEnum.DECIMAL32 : DType.DTypeEnum.DECIMAL64, scale); + try (ColumnVector sourceColumn = sourceData.get(); + ColumnVector expectedColumn = returnData.get(); + ColumnVector decimalColumn = sourceColumn.castTo(decimalType); + HostColumnVector hostDecimalColumn = decimalColumn.copyToHost(); + ColumnVector returnColumn = decimalColumn.castTo(sourceType)) { + for (int i = 0; i < sourceColumn.rows; i++) { + Long actual = hostDecimalColumn.isNull(i) ? null : + (isDec32 ? hostDecimalColumn.getInt(i) : hostDecimalColumn.getLong(i)); + assertEquals(unscaledDecimal[i], actual); + } + assertColumnsAreEqual(expectedColumn, returnColumn); + } + } + @Test void testIsTimestamp() { final String[] TIMESTAMP_STRINGS = {