Skip to content

Commit

Permalink
Update Java tests to expect DECIMAL128 from Arrow (#10073)
Browse files Browse the repository at this point in the history
After #9986 reading Arrow in libcudf now returns DECIMAL128 instead of DECIMAL64.  This updates the Java tests to expect DECIMAL128 instead of DECIMAL64 by upcasting the decimal columns in the original table being round-tripped through Arrow before comparing the result.

Authors:
  - Jason Lowe (https://github.com/jlowe)

Approvers:
  - Rong Ou (https://github.com/rongou)
  - MithunR (https://github.com/mythrocks)
  - Nghia Truong (https://github.com/ttnghia)

URL: #10073
  • Loading branch information
jlowe authored Jan 19, 2022
1 parent e416188 commit 3aecce2
Showing 1 changed file with 66 additions and 4 deletions.
70 changes: 66 additions & 4 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7192,6 +7192,64 @@ void testParquetWriteToFileUncompressedNoStats() throws IOException {
}
}

/** Return a column where DECIMAL64 has been up-casted to DECIMAL128 */
private ColumnVector castDecimal64To128(ColumnView c) {
DType dtype = c.getType();
switch (dtype.getTypeId()) {
case DECIMAL64:
return c.castTo(DType.create(DType.DTypeEnum.DECIMAL128, dtype.getScale()));
case STRUCT:
case LIST:
{
ColumnView[] oldViews = c.getChildColumnViews();
assert oldViews != null;
ColumnVector[] newChildren = new ColumnVector[oldViews.length];
try {
for (int i = 0; i < oldViews.length; i++) {
newChildren[i] = castDecimal64To128(oldViews[i]);
}
try (ColumnView newView = new ColumnView(dtype, c.getRowCount(),
Optional.of(c.getNullCount()), c.getValid(), c.getOffsets(), newChildren)) {
return newView.copyToColumnVector();
}
} finally {
for (ColumnView v : oldViews) {
v.close();
}
for (ColumnVector v : newChildren) {
if (v != null) {
v.close();
}
}
}
}
default:
if (c instanceof ColumnVector) {
return ((ColumnVector) c).incRefCount();
} else {
return c.copyToColumnVector();
}
}
}

/** Return a new Table with any DECIMAL64 columns up-casted to DECIMAL128 */
private Table castDecimal64To128(Table t) {
final int numCols = t.getNumberOfColumns();
ColumnVector[] cols = new ColumnVector[numCols];
try {
for (int i = 0; i < numCols; i++) {
cols[i] = castDecimal64To128(t.getColumn(i));
}
return new Table(cols);
} finally {
for (ColumnVector c : cols) {
if (c != null) {
c.close();
}
}
}
}

@Test
void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException {
File tempFile = File.createTempFile("test-names-metadata", ".arrow");
Expand All @@ -7203,15 +7261,17 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException {
try (TableWriter writer = Table.writeArrowIPCChunked(options, tempFile.getAbsoluteFile())) {
writer.write(table0);
}
try (StreamedTableReader reader = Table.readArrowIPCChunked(tempFile)) {
// Reading from Arrow converts decimals to DECIMAL128
try (StreamedTableReader reader = Table.readArrowIPCChunked(tempFile);
Table expected = castDecimal64To128(table0)) {
boolean done = false;
int count = 0;
while (!done) {
try (Table t = reader.getNextIfAvailable()) {
if (t == null) {
done = true;
} else {
assertTablesAreEqual(table0, t);
assertTablesAreEqual(expected, t);
count++;
}
}
Expand Down Expand Up @@ -7243,15 +7303,17 @@ void testArrowIPCWriteToBufferChunked() {
writer.write(table0);
writer.write(table0);
}
try (StreamedTableReader reader = Table.readArrowIPCChunked(new MyBufferProvider(consumer))) {
// Reading from Arrow converts decimals to DECIMAL128
try (StreamedTableReader reader = Table.readArrowIPCChunked(new MyBufferProvider(consumer));
Table expected = castDecimal64To128(table0)) {
boolean done = false;
int count = 0;
while (!done) {
try (Table t = reader.getNextIfAvailable()) {
if (t == null) {
done = true;
} else {
assertTablesAreEqual(table0, t);
assertTablesAreEqual(expected, t);
count++;
}
}
Expand Down

0 comments on commit 3aecce2

Please sign in to comment.