From 3aecce25701ea4d0f182d2a0f47237863ad15e69 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 19 Jan 2022 08:44:49 -0600 Subject: [PATCH] Update Java tests to expect DECIMAL128 from Arrow (#10073) After #9986 reading Arrow in libcudf now returns DECIMAL128 instead of DECIMAL64. This updates the Java tests to expect DECIMAL128 instead of DECIMAL64 by upcasting the decimal columns in the original table being round-tripped through Arrow before comparing the result. Authors: - Jason Lowe (https://github.com/jlowe) Approvers: - Rong Ou (https://github.com/rongou) - MithunR (https://github.com/mythrocks) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/10073 --- .../test/java/ai/rapids/cudf/TableTest.java | 70 +++++++++++++++++-- 1 file changed, 66 insertions(+), 4 deletions(-) diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 7fe69d2d7fc..18a0de77664 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -7192,6 +7192,64 @@ void testParquetWriteToFileUncompressedNoStats() throws IOException { } } + /** Return a column where DECIMAL64 has been up-casted to DECIMAL128 */ + private ColumnVector castDecimal64To128(ColumnView c) { + DType dtype = c.getType(); + switch (dtype.getTypeId()) { + case DECIMAL64: + return c.castTo(DType.create(DType.DTypeEnum.DECIMAL128, dtype.getScale())); + case STRUCT: + case LIST: + { + ColumnView[] oldViews = c.getChildColumnViews(); + assert oldViews != null; + ColumnVector[] newChildren = new ColumnVector[oldViews.length]; + try { + for (int i = 0; i < oldViews.length; i++) { + newChildren[i] = castDecimal64To128(oldViews[i]); + } + try (ColumnView newView = new ColumnView(dtype, c.getRowCount(), + Optional.of(c.getNullCount()), c.getValid(), c.getOffsets(), newChildren)) { + return newView.copyToColumnVector(); + } + } finally { + for (ColumnView v : oldViews) { + v.close(); + } + for (ColumnVector v : newChildren) { + if (v != null) { + v.close(); + } + } + } + } + default: + if (c instanceof ColumnVector) { + return ((ColumnVector) c).incRefCount(); + } else { + return c.copyToColumnVector(); + } + } + } + + /** Return a new Table with any DECIMAL64 columns up-casted to DECIMAL128 */ + private Table castDecimal64To128(Table t) { + final int numCols = t.getNumberOfColumns(); + ColumnVector[] cols = new ColumnVector[numCols]; + try { + for (int i = 0; i < numCols; i++) { + cols[i] = castDecimal64To128(t.getColumn(i)); + } + return new Table(cols); + } finally { + for (ColumnVector c : cols) { + if (c != null) { + c.close(); + } + } + } + } + @Test void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException { File tempFile = File.createTempFile("test-names-metadata", ".arrow"); @@ -7203,7 +7261,9 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException { try (TableWriter writer = Table.writeArrowIPCChunked(options, tempFile.getAbsoluteFile())) { writer.write(table0); } - try (StreamedTableReader reader = Table.readArrowIPCChunked(tempFile)) { + // Reading from Arrow converts decimals to DECIMAL128 + try (StreamedTableReader reader = Table.readArrowIPCChunked(tempFile); + Table expected = castDecimal64To128(table0)) { boolean done = false; int count = 0; while (!done) { @@ -7211,7 +7271,7 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException { if (t == null) { done = true; } else { - assertTablesAreEqual(table0, t); + assertTablesAreEqual(expected, t); count++; } } @@ -7243,7 +7303,9 @@ void testArrowIPCWriteToBufferChunked() { writer.write(table0); writer.write(table0); } - try (StreamedTableReader reader = Table.readArrowIPCChunked(new MyBufferProvider(consumer))) { + // Reading from Arrow converts decimals to DECIMAL128 + try (StreamedTableReader reader = Table.readArrowIPCChunked(new MyBufferProvider(consumer)); + Table expected = castDecimal64To128(table0)) { boolean done = false; int count = 0; while (!done) { @@ -7251,7 +7313,7 @@ void testArrowIPCWriteToBufferChunked() { if (t == null) { done = true; } else { - assertTablesAreEqual(table0, t); + assertTablesAreEqual(expected, t); count++; } }