diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d311431a55..f6321bec199 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,7 @@ - PR #6761 Add Java/JNI bindings for round - PR #6786 Add nested type support to ColumnVector#getDeviceMemorySize - PR #6780 Move `cudf::cast` tests to separate test file +- PR #6770 Support building decimal columns with Table.TestBuilder ## Bug Fixes diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index c059cf9c175..d31eab3a8f0 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -25,6 +25,8 @@ import ai.rapids.cudf.HostColumnVector.StructType; import java.io.File; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -2510,6 +2512,58 @@ public TestBuilder timestampSecondsColumn(Long... values) { return this; } + public TestBuilder decimal32Column(int scale, Integer... unscaledValues) { + types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, scale))); + typeErasedData.add(unscaledValues); + return this; + } + + public TestBuilder decimal32Column(int scale, RoundingMode mode, Double... values) { + types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, scale))); + BigDecimal[] data = Arrays.stream(values).map((x) -> { + if (x == null) return null; + return BigDecimal.valueOf(x).setScale(-scale, mode); + }).toArray(BigDecimal[]::new); + typeErasedData.add(data); + return this; + } + + public TestBuilder decimal32Column(int scale, RoundingMode mode, String... values) { + types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, scale))); + BigDecimal[] data = Arrays.stream(values).map((x) -> { + if (x == null) return null; + return new BigDecimal(x).setScale(-scale, mode); + }).toArray(BigDecimal[]::new); + typeErasedData.add(data); + return this; + } + + public TestBuilder decimal64Column(int scale, Long... unscaledValues) { + types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL64, scale))); + typeErasedData.add(unscaledValues); + return this; + } + + public TestBuilder decimal64Column(int scale, RoundingMode mode, Double... values) { + types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL64, scale))); + BigDecimal[] data = Arrays.stream(values).map((x) -> { + if (x == null) return null; + return BigDecimal.valueOf(x).setScale(-scale, mode); + }).toArray(BigDecimal[]::new); + typeErasedData.add(data); + return this; + } + + public TestBuilder decimal64Column(int scale, RoundingMode mode, String... values) { + types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL64, scale))); + BigDecimal[] data = Arrays.stream(values).map((x) -> { + if (x == null) return null; + return new BigDecimal(x).setScale(-scale, mode); + }).toArray(BigDecimal[]::new); + typeErasedData.add(data); + return this; + } + private static ColumnVector from(DType type, Object dataArray) { ColumnVector ret = null; switch (type.typeId) { @@ -2552,6 +2606,27 @@ private static ColumnVector from(DType type, Object dataArray) { case FLOAT64: ret = ColumnVector.fromBoxedDoubles((Double[]) dataArray); break; + case DECIMAL32: + case DECIMAL64: + int scale = type.getScale(); + if (dataArray instanceof Integer[]) { + BigDecimal[] data = Arrays.stream(((Integer[]) dataArray)) + .map((i) -> i == null ? null : BigDecimal.valueOf(i, -scale)) + .toArray(BigDecimal[]::new); + ret = ColumnVector.build(type, data.length, (b) -> b.appendBoxed(data)); + } else if (dataArray instanceof Long[]) { + BigDecimal[] data = Arrays.stream(((Long[]) dataArray)) + .map((i) -> i == null ? null : BigDecimal.valueOf(i, -scale)) + .toArray(BigDecimal[]::new); + ret = ColumnVector.build(type, data.length, (b) -> b.appendBoxed(data)); + } else if (dataArray instanceof BigDecimal[]) { + BigDecimal[] data = (BigDecimal[]) dataArray; + ret = ColumnVector.build(type, data.length, (b) -> b.appendBoxed(data)); + } else { + throw new IllegalArgumentException( + "Data array of invalid type(" + dataArray.getClass() + ") to build decimal column"); + } + break; default: throw new IllegalArgumentException(type + " is not supported yet"); } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 97ceb23c837..fae2d4694d3 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -33,6 +33,7 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; +import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.ArrayList; @@ -1452,10 +1453,14 @@ void testRepeat() { try (Table t = new Table.TestBuilder() .column(1, 2) .column("a", "b") + .decimal32Column(-3, 12, -25) + .decimal64Column(2, 11111L, -22222L) .build(); Table expected = new Table.TestBuilder() .column(1, 1, 1, 2, 2, 2) .column("a", "a", "a", "b", "b", "b") + .decimal32Column(-3, 12, 12, 12, -25, -25, -25) + .decimal64Column(2, 11111L, 11111L, 11111L, -22222L, -22222L, -22222L) .build(); Table repeated = t.repeat(3)) { assertTablesAreEqual(expected, repeated); @@ -1467,11 +1472,15 @@ void testRepeatColumn() { try (Table t = new Table.TestBuilder() .column(1, 2) .column("a", "b") + .decimal32Column(-3, 12, -25) + .decimal64Column(2, 11111L, -22222L) .build(); ColumnVector repeats = ColumnVector.fromBytes((byte)1, (byte)4); Table expected = new Table.TestBuilder() .column(1, 2, 2, 2, 2) .column("a", "b", "b", "b", "b") + .decimal32Column(-3, 12, -25, -25, -25, -25) + .decimal64Column(2, 11111L, -22222L, -22222L, -22222L, -22222L) .build(); Table repeated = t.repeat(repeats)) { assertTablesAreEqual(expected, repeated); @@ -1513,6 +1522,26 @@ void testInterleaveFloatColumns() { } } + @Test + void testInterleaveDecimalColumns() { + try (Table t = new Table.TestBuilder() + .decimal32Column(-2, 123, 456, 789) + .decimal32Column(-2,-100, -200, -300) + .build(); + ColumnVector expected = ColumnVector.decimalFromInts(-2, 123, -100, 456, -200, 789, -300); + ColumnVector actual = t.interleaveColumns()) { + assertColumnsAreEqual(expected, actual); + } + try (Table t = new Table.TestBuilder() + .decimal64Column(-5, 123456790L, 987654321L) + .decimal64Column(-5,-123456790L, -987654321L) + .build(); + ColumnVector expected = ColumnVector.decimalFromLongs(-5, 123456790L, -123456790L, 987654321L, -987654321L); + ColumnVector actual = t.interleaveColumns()) { + assertColumnsAreEqual(expected, actual); + } + } + @Test void testInterleaveStringColumns() { try (Table t = new Table.TestBuilder() @@ -1542,23 +1571,35 @@ void testConcatNoNulls() { .column(1, 2, 3) .column("1", "2", "3") .timestampMicrosecondsColumn(1L, 2L, 3L) - .column(11.0, 12.0, 13.0).build(); + .column(11.0, 12.0, 13.0) + .decimal32Column(-3, 1, 2, 3) + .decimal64Column(-10, 1L, 2L, 3L) + .build(); Table t2 = new Table.TestBuilder() .column(4, 5) .column("4", "3") .timestampMicrosecondsColumn(4L, 3L) - .column(14.0, 15.0).build(); + .column(14.0, 15.0) + .decimal32Column(-3, 4, 5) + .decimal64Column(-10, 4L, 5L) + .build(); Table t3 = new Table.TestBuilder() .column(6, 7, 8, 9) .column("4", "1", "2", "2") .timestampMicrosecondsColumn(4L, 1L, 2L, 2L) - .column(16.0, 17.0, 18.0, 19.0).build(); + .column(16.0, 17.0, 18.0, 19.0) + .decimal32Column(-3, 6, 7, 8, 9) + .decimal64Column(-10, 6L, 7L, 8L, 9L) + .build(); Table concat = Table.concatenate(t1, t2, t3); Table expected = new Table.TestBuilder() .column(1, 2, 3, 4, 5, 6, 7, 8, 9) .column("1", "2", "3", "4", "3", "4", "1", "2", "2") .timestampMicrosecondsColumn(1L, 2L, 3L, 4L, 3L, 4L, 1L, 2L, 2L) - .column(11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0).build()) { + .column(11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0) + .decimal32Column(-3, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .decimal64Column(-10, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L) + .build()) { assertTablesAreEqual(expected, concat); } } @@ -1567,17 +1608,29 @@ void testConcatNoNulls() { void testConcatWithNulls() { try (Table t1 = new Table.TestBuilder() .column(1, null, 3) - .column(11.0, 12.0, 13.0).build(); + .column(11.0, 12.0, 13.0) + .decimal32Column(-3, 1, null, 3) + .decimal64Column(-10, 11L, 12L, 13L) + .build(); Table t2 = new Table.TestBuilder() .column(4, null) - .column(14.0, 15.0).build(); + .column(14.0, 15.0) + .decimal32Column(-3, 4, null) + .decimal64Column(-10, 14L, 15L) + .build(); Table t3 = new Table.TestBuilder() .column(6, 7, 8, 9) - .column(null, null, 18.0, 19.0).build(); + .column(null, null, 18.0, 19.0) + .decimal32Column(-3, 6, 7, 8, 9) + .decimal64Column(-10, null, null, 18L, 19L) + .build(); Table concat = Table.concatenate(t1, t2, t3); Table expected = new Table.TestBuilder() .column(1, null, 3, 4, null, 6, 7, 8, 9) - .column(11.0, 12.0, 13.0, 14.0, 15.0, null, null, 18.0, 19.0).build()) { + .column(11.0, 12.0, 13.0, 14.0, 15.0, null, null, 18.0, 19.0) + .decimal32Column(-3, 1, null, 3, 4, null, 6, 7, 8, 9) + .decimal64Column(-10, 11L, 12L, 13L, 14L, 15L, null, null, 18L, 19L) + .build()) { assertTablesAreEqual(expected, concat); } } @@ -1588,6 +1641,8 @@ void testContiguousSplit() { try (Table t1 = new Table.TestBuilder() .column(10, 12, 14, 16, 18, 20, 22, 24, null, 28) .column(50, 52, 54, 56, 58, 60, 62, 64, 66, null) + .decimal32Column(-3, 10, 12, 14, 16, 18, 20, 22, 24, null, 28) + .decimal64Column(-8, 50L, 52L, 54L, 56L, 58L, 60L, 62L, 64L, 66L, null) .build()) { splits = t1.contiguousSplit(2, 5, 9); assertEquals(4, splits.length); @@ -1611,6 +1666,8 @@ void testContiguousSplitWithStrings() { .column(10, 12, 14, 16, 18, 20, 22, 24, null, 28) .column(50, 52, 54, 56, 58, 60, 62, 64, 66, null) .column("A", "B", "C", "D", "E", "F", "G", "H", "I", "J") + .decimal32Column(-3, 10, 12, 14, 16, 18, 20, 22, 24, null, 28) + .decimal64Column(-8, 50L, 52L, 54L, 56L, 58L, 60L, 62L, 64L, 66L, null) .build()) { splits = t1.contiguousSplit(2, 5, 9); assertEquals(4, splits.length); @@ -1837,6 +1894,8 @@ void testRoundRobinPartition() { .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L) .column( "A", "B", "C", "D", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") .column( "A", "A", "C", "C", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") + .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) + .decimal64Column( -8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) .build()) { try (Table expectedTable = new Table.TestBuilder() .column( 100, 40004, 1, null, 7, null, 13, 202, 5, null, 5, null, 11, null, 3003, -60, 3, null, 9, null, 15) @@ -1851,6 +1910,8 @@ void testRoundRobinPartition() { .timestampSecondsColumn(1L, 4L, 1L, 4L, 7L, null, 13L, null, 5L, 2L, 5L, 8L, 11L, 14L, 3L, 6L, 3L, 6L, 9L, 12L, 15L) .column( "A", "D", "1", "4", "7", "10", "13", "B", null, "2", "5", null, "11", null, "C", "TESTING", "3", "6", "9", "12", "15") .column( "A", "C", "1", "4", "7", "10", "13", "A", null, "2", "5", null, "11", null, "C", "TESTING", "3", "6", "9", "12", "15") + .decimal32Column(-3, 100, 40004, 1, null, 7, null, 13, 202, 5, null, 5, null, 11, null, 3003, -60, 3, null, 9, null, 15) + .decimal64Column(-8, 1L, 50L, 1L, 4L, 7L, null, 13L, null, -2000L, 2L, null, 8L, 11L, 14L, 1001L, null, 3L, 6L, 9L, 12L, null) .build(); PartitionedTable pt = t.roundRobinPartition(3, 0)) { assertTablesAreEqual(expectedTable, pt.getTable()); @@ -1874,6 +1935,8 @@ void testRoundRobinPartition() { .timestampSecondsColumn(3L, 6L, 3L, 6L, 9L, 12L, 15L, 1L, 4L, 1L, 4L, 7L, null, 13L, null, 5L, 2L, 5L, 8L, 11L, 14L) .column( "C", "TESTING", "3", "6", "9", "12", "15", "A", "D", "1", "4", "7", "10", "13", "B", null, "2", "5", null, "11", null) .column( "C", "TESTING", "3", "6", "9", "12", "15", "A", "C", "1", "4", "7", "10", "13", "A", null, "2", "5", null, "11", null) + .decimal32Column(-3, 3003, -60, 3, null, 9, null, 15, 100, 40004, 1, null, 7, null, 13, 202, 5, null, 5, null, 11, null) + .decimal64Column(-8, 1001L, null, 3L, 6L, 9L, 12L, null, 1L, 50L, 1L, 4L, 7L, null, 13L, null, -2000L, 2L, null, 8L, 11L, 14L) .build(); PartitionedTable pt = t.roundRobinPartition(3, 1)) { assertTablesAreEqual(expectedTable, pt.getTable()); @@ -1897,6 +1960,8 @@ void testRoundRobinPartition() { .timestampSecondsColumn(null, 5L, 2L, 5L, 8L, 11L, 14L, 3L, 6L, 3L, 6L, 9L, 12L, 15L, 1L, 4L, 1L, 4L, 7L, null, 13L) .column( "B", null, "2", "5", null, "11", null, "C", "TESTING", "3", "6", "9", "12", "15", "A", "D", "1", "4", "7", "10", "13") .column( "A", null, "2", "5", null, "11", null, "C", "TESTING", "3", "6", "9", "12", "15", "A", "C", "1", "4", "7", "10", "13") + .decimal32Column(-3, 202, 5, null, 5, null, 11, null, 3003, -60, 3, null, 9, null, 15, 100, 40004, 1, null, 7, null, 13) + .decimal64Column(-8, null, -2000L, 2L, null, 8L, 11L, 14L, 1001L, null, 3L, 6L, 9L, 12L, null, 1L, 50L, 1L, 4L, 7L, null, 13L) .build(); PartitionedTable pt = t.roundRobinPartition(3, 2)) { assertTablesAreEqual(expectedTable, pt.getTable()); @@ -1958,6 +2023,9 @@ void testConcatHost() throws IOException { .column( 1, 2, null, 4, 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null) + .decimal32Column(-3, + 1, 2, null, 4, 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, + 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null) .build(); Table expected = new Table.TestBuilder() .column( @@ -1969,6 +2037,15 @@ void testConcatHost() throws IOException { 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null) + .decimal32Column(-3, + null, 12, 13, 14, null, null, + 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, + 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, + 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, + null, 12, 13, 14, null, null, + 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, + 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null, + 1, 2, null, 4 , 5, 6, 7, 8, 9, 10, null, 12, 13, 14, null, null) .build(); Table t2 = t1.concatenate(t1, t1)) { ByteArrayOutputStream out = new ByteArrayOutputStream(); @@ -3210,6 +3287,8 @@ void testGroupByNoAggs() { try (Table t1 = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1) .column( 1, 3, 3, 5, 5, 0) .column( 12, 14, 13, 17, 17, 17) + .decimal32Column(-3, 12, 14, 13, 111, 222, 333) + .decimal64Column(-3, 12L, 14L, 13L, 111L, 222L, 333L) .build()) { try (Table t3 = t1.groupBy(0, 1).aggregate()) { // verify t3 @@ -3223,11 +3302,15 @@ void testSimpleGather() { try (Table testTable = new Table.TestBuilder() .column(1, 2, 3, 4, 5) .column("A", "AA", "AAA", "AAAA", "AAAAA") + .decimal32Column(-3, 1, 2, 3, 4, 5) + .decimal64Column(-8, 100001L, 200002L, 300003L, 400004L, 500005L) .build(); ColumnVector gatherMap = ColumnVector.fromInts(0, 2, 4, -2); Table expected = new Table.TestBuilder() .column(1, 3, 5, 4) .column("A", "AAA", "AAAAA", "AAAA") + .decimal32Column(-3, 1, 3, 5, 4) + .decimal64Column(-8, 100001L, 300003L, 500005L, 400004L) .build(); Table found = testTable.gather(gatherMap)) { assertTablesAreEqual(expected, found); @@ -3674,6 +3757,8 @@ void fixedWidthRowsRoundTrip() { .column(true, false, false, true, false, null) .column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f, null) .column(new Byte[]{2, 3, 4, 5, 9, null}) + .decimal32Column(-3, RoundingMode.UNNECESSARY, 5.0d, 9.5d, 0.9d, 7.23d, 2.8d, null) + .decimal64Column(-8, 3L, 9L, 4L, 2L, 20L, null) .build()) { ColumnVector[] rows = t.convertToRows(); try { @@ -3733,6 +3818,8 @@ private Table buildTestTable() { .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12, 13, null, 15) .timestampMillisecondsColumn(9L, 1006L, 101L, 5092L, null, 88L, 1L, 2L, 3L, 4L, 5L ,6L, 7L, 8L, null, 10L, 11L, 12L, 13L, 14L, 15L) .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L ,6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L) + .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) + .decimal64Column(-8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) .column( "A", "B", "C", "D", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") .column( strings("1", "2", "3"), strings("4"), strings("5"), strings("6, 7"), @@ -3762,4 +3849,22 @@ null, structs(struct("3", "4"), struct("1", "2")), .column( "A", "A", "C", "C", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") .build(); } + + @Test + void testBuilderWithColumn() { + try (Table t1 = new Table.TestBuilder() + .decimal32Column(-3, 120, -230, null, 340) + .decimal64Column(-8, 1000L, 200L, null, 30L).build()) { + try (Table t2 = new Table.TestBuilder() + .decimal32Column(-3, RoundingMode.UNNECESSARY, 0.12, -0.23, null, 0.34) + .decimal64Column(-8, RoundingMode.UNNECESSARY, 1e-5, 2e-6, null, 3e-7).build()) { + try (Table t3 = new Table.TestBuilder() + .decimal32Column(-3, RoundingMode.UNNECESSARY, "0.12", "-000.23", null, ".34") + .decimal64Column(-8, RoundingMode.UNNECESSARY, "1e-5", "2e-6", null, "3e-7").build()) { + assertTablesAreEqual(t1, t2); + assertTablesAreEqual(t1, t3); + } + } + } + } }