Skip to content

Commit

Permalink
Add Java tests for decimal casts(#7051)
Browse files Browse the repository at this point in the history
This pull request attempts to verify the support of decimal cast in terms of java package.

Authors:
  - sperlingxx <[email protected]>

Approvers:
  - Jason Lowe (@jlowe)

URL: #7051
  • Loading branch information
sperlingxx authored Jan 8, 2021
1 parent aa38f85 commit 30e154c
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- PR #6939 Use simplified `rmm::exec_policy`
- PR #6512 Refactor rolling.cu to reduce compile time
- PR #6982 Disable some pragma unroll statements in thrust `sort.h`
- PR #7051 Verify decimal cast in java package

## Bug Fixes

Expand Down
131 changes: 131 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2215,6 +2215,137 @@ private static void testCastFixedWidthToStringsAndBack(DType type, Supplier<Colu
}
}

@Test
void testCastIntToDecimal() {
testCastNumericToDecimalsAndBack(DType.INT32, true, 0,
() -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, Integer.MIN_VALUE, Integer.MAX_VALUE),
() -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, Integer.MIN_VALUE, Integer.MAX_VALUE),
new Long[]{1L, -21L, 345L, null, 8008L, (long) Integer.MIN_VALUE, (long) Integer.MAX_VALUE}
);
testCastNumericToDecimalsAndBack(DType.INT32, false, -2,
() -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, 0, 123456),
() -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, 0, 123456),
new Long[]{100L, -2100L, 34500L, null, 800800L, 0L, 12345600L}
);
testCastNumericToDecimalsAndBack(DType.INT32, false, 2,
() -> ColumnVector.fromBoxedInts(1, -21, 345, null, 8008, 0, 123456),
() -> ColumnVector.fromBoxedInts(0, 0, 300, null, 8000, 0, 123400),
new Long[]{0L, 0L, 3L, null, 80L, 0L, 1234L}
);
}

@Test
void testCastLongToDecimal() {
testCastNumericToDecimalsAndBack(DType.INT64, false, 0,
() -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8008L, Long.MIN_VALUE, Long.MAX_VALUE),
() -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8008L, Long.MIN_VALUE, Long.MAX_VALUE),
new Long[]{1L, -21L, 345L, null, 8008L, Long.MIN_VALUE, Long.MAX_VALUE}
);
testCastNumericToDecimalsAndBack(DType.INT64, false, -1,
() -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8008L, 0L, 123456L),
() -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8008L, 0L, 123456L),
new Long[]{10L, -210L, 3450L, null, 80080L, 0L, 1234560L}
);
testCastNumericToDecimalsAndBack(DType.INT64, false, 1,
() -> ColumnVector.fromBoxedLongs(1L, -21L, 345L, null, 8018L, 0L, 123456L),
() -> ColumnVector.fromBoxedLongs(0L, -20L, 340L, null, 8010L, 0L, 123450L),
new Long[]{0L, -2L, 34L, null, 801L, 0L, 12345L}
);
}

@Test
void testCastFloatToDecimal() {
testCastNumericToDecimalsAndBack(DType.FLOAT32, true, 0,
() -> ColumnVector.fromBoxedFloats(1.0f, 2.1f, -3.23f, null, 2.41281f, 1378952.001f),
() -> ColumnVector.fromBoxedFloats(1f, 2f, -3f, null, 2f, 1378952f),
new Long[]{1L, 2L, -3L, null, 2L, 1378952L}
);
testCastNumericToDecimalsAndBack(DType.FLOAT32, true, -1,
() -> ColumnVector.fromBoxedFloats(1.0f, 2.1f, -3.23f, null, 2.41281f, 1378952.001f),
() -> ColumnVector.fromBoxedFloats(1f, 2.1f, -3.2f, null, 2.4f, 1378952f),
new Long[]{10L, 21L, -32L, null, 24L, 13789520L}
);
testCastNumericToDecimalsAndBack(DType.FLOAT32, true, 1,
() -> ColumnVector.fromBoxedFloats(1.0f, 21.1f, -300.23f, null, 24128.1f, 1378952.001f),
() -> ColumnVector.fromBoxedFloats(0f, 20f, -300f, null, 24120f, 1378950f),
new Long[]{0L, 2L, -30L, null, 2412L, 137895L}
);
}

@Test
void testCastDoubleToDecimal() {
testCastNumericToDecimalsAndBack(DType.FLOAT64, false, 0,
() -> ColumnVector.fromBoxedDoubles(1.0, 2.1, -3.23, null, 2.41281, (double) Long.MAX_VALUE),
() -> ColumnVector.fromBoxedDoubles(1.0, 2.0, -3.0, null, 2.0, (double) Long.MAX_VALUE),
new Long[]{1L, 2L, -3L, null, 2L, Long.MAX_VALUE}
);
testCastNumericToDecimalsAndBack(DType.FLOAT64, false, -2,
() -> ColumnVector.fromBoxedDoubles(1.0, 2.1, -3.23, null, 2.41281, -55.01999),
() -> ColumnVector.fromBoxedDoubles(1.0, 2.1, -3.23, null, 2.41, -55.01),
new Long[]{100L, 210L, -323L, null, 241L, -5501L}
);
testCastNumericToDecimalsAndBack(DType.FLOAT64, false, 1,
() -> ColumnVector.fromBoxedDoubles(1.0, 23.1, -3089.23, null, 200.41281, -199.01999),
() -> ColumnVector.fromBoxedDoubles(0.0, 20.0, -3080.0, null, 200.0, -190.0),
new Long[]{0L, 2L, -308L, null, 20L, -19L}
);
}

@Test
void testCastDecimalToDecimal() {
// DECIMAL32(scale: 0) -> DECIMAL32(scale: 0)
testCastNumericToDecimalsAndBack(DType.create(DType.DTypeEnum.DECIMAL32, 0), true, -0,
() -> ColumnVector.decimalFromInts(0, 1, 12, -234, 5678, Integer.MIN_VALUE / 100),
() -> ColumnVector.decimalFromInts(0, 1, 12, -234, 5678, Integer.MIN_VALUE / 100),
new Long[]{1L, 12L, -234L, 5678L, (long) Integer.MIN_VALUE / 100}
);
// DECIMAL32(scale: 0) -> DECIMAL64(scale: -2)
testCastNumericToDecimalsAndBack(DType.create(DType.DTypeEnum.DECIMAL32, 0), false, -2,
() -> ColumnVector.decimalFromInts(0, 1, 12, -234, 5678, Integer.MIN_VALUE / 100),
() -> ColumnVector.decimalFromInts(0, 1, 12, -234, 5678, Integer.MIN_VALUE / 100),
new Long[]{100L, 1200L, -23400L, 567800L, (long) Integer.MIN_VALUE / 100 * 100}
);
// DECIMAL64(scale: -3) -> DECIMAL64(scale: -1)
DType dt = DType.create(DType.DTypeEnum.DECIMAL64, -3);
testCastNumericToDecimalsAndBack(dt, false, -1,
() -> ColumnVector.decimalFromDoubles(dt, RoundingMode.UNNECESSARY, -1000.1, 1.222, 0.03, -4.678, 16789431.0),
() -> ColumnVector.decimalFromDoubles(dt, RoundingMode.UNNECESSARY, -1000.1, 1.2, 0, -4.6, 16789431.0),
new Long[]{-10001L, 12L, 0L, -46L, 167894310L}
);
// DECIMAL64(scale: -3) -> DECIMAL64(scale: 2)
DType dt2 = DType.create(DType.DTypeEnum.DECIMAL64, -3);
testCastNumericToDecimalsAndBack(dt2, false, 2,
() -> ColumnVector.decimalFromDoubles(dt2, RoundingMode.UNNECESSARY, -1013.1, 14.222, 780.03, -4.678, 16789431.0),
() -> ColumnVector.decimalFromDoubles(dt2, RoundingMode.UNNECESSARY, -1000, 0, 700, 0, 16789400),
new Long[]{-10L, 0L, 7L, 0L, 167894L}
);
// DECIMAL64(scale: -3) -> DECIMAL32(scale: -3)
testCastNumericToDecimalsAndBack(dt2, true, -3,
() -> ColumnVector.decimalFromDoubles(dt2, RoundingMode.UNNECESSARY, -1013.1, 14.222, 780.03, -4.678, 16789.0),
() -> ColumnVector.decimalFromDoubles(dt2, RoundingMode.UNNECESSARY, -1013.1, 14.222, 780.03, -4.678, 16789.0),
new Long[]{-1013100L, 14222L, 780030L, -4678L, 16789000L}
);
}

private static void testCastNumericToDecimalsAndBack(DType sourceType, boolean isDec32, int scale,
Supplier<ColumnVector> sourceData,
Supplier<ColumnVector> returnData,
Long[] unscaledDecimal) {
DType decimalType = DType.create(isDec32 ? DType.DTypeEnum.DECIMAL32 : DType.DTypeEnum.DECIMAL64, scale);
try (ColumnVector sourceColumn = sourceData.get();
ColumnVector expectedColumn = returnData.get();
ColumnVector decimalColumn = sourceColumn.castTo(decimalType);
HostColumnVector hostDecimalColumn = decimalColumn.copyToHost();
ColumnVector returnColumn = decimalColumn.castTo(sourceType)) {
for (int i = 0; i < sourceColumn.rows; i++) {
Long actual = hostDecimalColumn.isNull(i) ? null :
(isDec32 ? hostDecimalColumn.getInt(i) : hostDecimalColumn.getLong(i));
assertEquals(unscaledDecimal[i], actual);
}
assertColumnsAreEqual(expectedColumn, returnColumn);
}
}

@Test
void testIsTimestamp() {
final String[] TIMESTAMP_STRINGS = {
Expand Down

0 comments on commit 30e154c

Please sign in to comment.