diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 7fe69d2d7fc..18a0de77664 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -7192,6 +7192,64 @@ void testParquetWriteToFileUncompressedNoStats() throws IOException { } } + /** Return a column where DECIMAL64 has been up-casted to DECIMAL128 */ + private ColumnVector castDecimal64To128(ColumnView c) { + DType dtype = c.getType(); + switch (dtype.getTypeId()) { + case DECIMAL64: + return c.castTo(DType.create(DType.DTypeEnum.DECIMAL128, dtype.getScale())); + case STRUCT: + case LIST: + { + ColumnView[] oldViews = c.getChildColumnViews(); + assert oldViews != null; + ColumnVector[] newChildren = new ColumnVector[oldViews.length]; + try { + for (int i = 0; i < oldViews.length; i++) { + newChildren[i] = castDecimal64To128(oldViews[i]); + } + try (ColumnView newView = new ColumnView(dtype, c.getRowCount(), + Optional.of(c.getNullCount()), c.getValid(), c.getOffsets(), newChildren)) { + return newView.copyToColumnVector(); + } + } finally { + for (ColumnView v : oldViews) { + v.close(); + } + for (ColumnVector v : newChildren) { + if (v != null) { + v.close(); + } + } + } + } + default: + if (c instanceof ColumnVector) { + return ((ColumnVector) c).incRefCount(); + } else { + return c.copyToColumnVector(); + } + } + } + + /** Return a new Table with any DECIMAL64 columns up-casted to DECIMAL128 */ + private Table castDecimal64To128(Table t) { + final int numCols = t.getNumberOfColumns(); + ColumnVector[] cols = new ColumnVector[numCols]; + try { + for (int i = 0; i < numCols; i++) { + cols[i] = castDecimal64To128(t.getColumn(i)); + } + return new Table(cols); + } finally { + for (ColumnVector c : cols) { + if (c != null) { + c.close(); + } + } + } + } + @Test void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException { File tempFile = File.createTempFile("test-names-metadata", ".arrow"); @@ -7203,7 +7261,9 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException { try (TableWriter writer = Table.writeArrowIPCChunked(options, tempFile.getAbsoluteFile())) { writer.write(table0); } - try (StreamedTableReader reader = Table.readArrowIPCChunked(tempFile)) { + // Reading from Arrow converts decimals to DECIMAL128 + try (StreamedTableReader reader = Table.readArrowIPCChunked(tempFile); + Table expected = castDecimal64To128(table0)) { boolean done = false; int count = 0; while (!done) { @@ -7211,7 +7271,7 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException { if (t == null) { done = true; } else { - assertTablesAreEqual(table0, t); + assertTablesAreEqual(expected, t); count++; } } @@ -7243,7 +7303,9 @@ void testArrowIPCWriteToBufferChunked() { writer.write(table0); writer.write(table0); } - try (StreamedTableReader reader = Table.readArrowIPCChunked(new MyBufferProvider(consumer))) { + // Reading from Arrow converts decimals to DECIMAL128 + try (StreamedTableReader reader = Table.readArrowIPCChunked(new MyBufferProvider(consumer)); + Table expected = castDecimal64To128(table0)) { boolean done = false; int count = 0; while (!done) { @@ -7251,7 +7313,7 @@ void testArrowIPCWriteToBufferChunked() { if (t == null) { done = true; } else { - assertTablesAreEqual(table0, t); + assertTablesAreEqual(expected, t); count++; } }