Parquet option for strictly decimal reading (#6908)
This PR adds a Parquet option that determines whether to strictly read all decimal columns as fixed-point decimal types, or to convert decimal columns that are not backed by INT32/INT64 to float64.
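As a quick illustration (not part of the diff below), the new option can be used from the Java API roughly as follows; the file path is hypothetical:

    import java.io.File;
    import ai.rapids.cudf.ParquetOptions;
    import ai.rapids.cudf.Table;

    ParquetOptions opts = ParquetOptions.builder()
        .enableStrictDecimalType(true)   // fail instead of falling back to float64
        .build();
    try (Table table = Table.readParquet(opts, new File("example.parquet"))) {
      // With strict mode on, every decimal column is DECIMAL32 or DECIMAL64.
      long rows = table.getRowCount();
    }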
sperlingxx authored Dec 4, 2020
1 parent e22c3ae commit cd7a0ad
Showing 6 changed files with 59 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -127,6 +127,7 @@
- PR #6837 Avoid gather when copying strings view from start of strings column
- PR #6859 Move align_ptr_for_type() from cuda.cuh to alignment.hpp
- PR #6807 Refactor `std::array` usage in row group index writing in ORC
- PR #6908 Parquet option for strictly decimal reading

## Bug Fixes

18 changes: 18 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ParquetOptions.java
@@ -27,22 +27,30 @@ public class ParquetOptions extends ColumnFilterOptions {

private final DType unit;

private final boolean strictDecimalType;


private ParquetOptions(Builder builder) {
super(builder);
unit = builder.unit;
strictDecimalType = builder.strictDecimalType;
}

DType timeUnit() {
return unit;
}

boolean isStrictDecimalType() {
return strictDecimalType;
}

public static Builder builder() {
return new Builder();
}

public static class Builder extends ColumnFilterOptions.Builder<Builder> {
private DType unit = DType.EMPTY;
private boolean strictDecimalType = false;

/**
* Specify the time unit to use when returning timestamps.
@@ -55,6 +63,16 @@ public Builder withTimeUnit(DType unit) {
return this;
}

/**
* Specify how to handle decimal columns that are not backed by INT32/INT64 while reading.
* @param strictDecimalType whether to strictly read all decimal columns as fixed-point decimal types
* @return builder for chaining
*/
public Builder enableStrictDecimalType(boolean strictDecimalType) {
this.strictDecimalType = strictDecimalType;
return this;
}

public ParquetOptions build() {
return new ParquetOptions(this);
}
22 changes: 13 additions & 9 deletions java/src/main/java/ai/rapids/cudf/Table.java
@@ -219,15 +219,17 @@ private static native long[] readCSV(String[] columnNames, String[] dTypes,

/**
* Read in Parquet formatted data.
* @param filterColumnNames name of the columns to read, or an empty array if we want to read
*                          all of them
* @param filePath the path of the file to read, or null if no path should be read.
* @param address the address of the buffer to read from or 0 if we should not.
* @param length the length of the buffer to read from.
* @param timeUnit return type of TimeStamp in units
* @param strictDecimalTypes whether to strictly read all decimal columns as fixed-point decimal types
*/
private static native long[] readParquet(String[] filterColumnNames, String filePath,
long address, long length, int timeUnit,
boolean strictDecimalTypes) throws CudfException;

/**
* Setup everything to write parquet formatted data to a file.
@@ -618,7 +620,8 @@ public static Table readParquet(File path) {
*/
public static Table readParquet(ParquetOptions opts, File path) {
return new Table(readParquet(opts.getIncludeColumnNames(),
path.getAbsolutePath(), 0, 0, opts.timeUnit().typeId.getNativeId(),
opts.isStrictDecimalType()));
}

/**
@@ -678,7 +681,8 @@ public static Table readParquet(ParquetOptions opts, HostMemoryBuffer buffer,
assert len <= buffer.getLength() - offset;
assert offset >= 0 && offset < buffer.length;
return new Table(readParquet(opts.getIncludeColumnNames(),
null, buffer.getAddress() + offset, len, opts.timeUnit().typeId.getNativeId(),
opts.isStrictDecimalType()));
}

/**
3 changes: 2 additions & 1 deletion java/src/main/native/src/TableJni.cpp
@@ -787,7 +787,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV(

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet(
JNIEnv *env, jclass j_class_object, jobjectArray filter_col_names, jstring inputfilepath,
jlong buffer, jlong buffer_length, jint unit, jboolean strict_decimal_types) {
bool read_buffer = true;
if (buffer == 0) {
JNI_NULL_CHECK(env, inputfilepath, "input file or buffer must be supplied", NULL);
@@ -823,6 +823,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet(
.convert_strings_to_categories(false)
.timestamp_type(cudf::data_type(static_cast<cudf::type_id>(unit)))
.build();
opts.set_strict_decimal_types(static_cast<bool>(strict_decimal_types));
cudf::io::table_with_metadata result = cudf::io::read_parquet(opts);
return cudf::jni::convert_table_for_return(env, result.tbl);
}
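
For context, a minimal standalone sketch of how the same option is exercised directly against libcudf (file path assumed):

    #include <cudf/io/parquet.hpp>
    #include <string>

    cudf::io::table_with_metadata read_strict_decimals(std::string const &path) {
      auto opts = cudf::io::parquet_reader_options::builder(cudf::io::source_info{path})
                      .convert_strings_to_categories(false)
                      .build();
      // Error out instead of converting unsupported decimals (e.g. those backed
      // by FIXED_LEN_BYTE_ARRAY) to float64.
      opts.set_strict_decimal_types(true);
      return cudf::io::read_parquet(opts);
    }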
25 changes: 25 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -67,6 +67,7 @@ public class TableTest extends CudfTestBase {
private static final File TEST_ORC_FILE = new File("src/test/resources/TestOrcFile.orc");
private static final File TEST_ORC_TIMESTAMP_DATE_FILE = new File(
"src/test/resources/timestamp-date-test.orc");
private static final File TEST_DECIMAL_PARQUET_FILE = new File("src/test/resources/decimal.parquet");

private static final Schema CSV_DATA_BUFFER_SCHEMA = Schema.builder()
.column(DType.INT32, "A")
@@ -717,6 +718,30 @@ void testReadParquetFull() {
}
}

@Test
void testReadParquetContainsDecimalData() {
try (Table table = Table.readParquet(TEST_DECIMAL_PARQUET_FILE)) {
long rows = table.getRowCount();
assertEquals(100, rows);
DType[] expectedTypes = new DType[]{
DType.create(DType.DTypeEnum.DECIMAL64, 0), // Decimal(18, 0)
DType.create(DType.DTypeEnum.DECIMAL32, -3), // Decimal(7, 3)
DType.create(DType.DTypeEnum.DECIMAL64, -10), // Decimal(10, 10)
DType.create(DType.DTypeEnum.DECIMAL32, 0), // Decimal(1, 0)
DType.create(DType.DTypeEnum.DECIMAL64, -15), // Decimal(18, 15)
DType.FLOAT64, // Decimal(20, 10) which is backed by FIXED_LEN_BYTE_ARRAY
DType.INT64,
DType.FLOAT32
};
assertTableTypes(expectedTypes, table);
}
// A CudfException will be thrown here because reading decimals stored as FIXED_LEN_BYTE_ARRAY is not yet supported.
ParquetOptions opts = ParquetOptions.builder().enableStrictDecimalType(true).build();
assertThrows(ai.rapids.cudf.CudfException.class, () -> {
try (Table table = Table.readParquet(opts, TEST_DECIMAL_PARQUET_FILE)) {}
});
}
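
Callers that prefer fixed-point decimals but can tolerate the float64 fallback might wrap the strict read, as in this hedged sketch (helper name hypothetical):

    static Table readPreferStrict(File path) {
      ParquetOptions strict = ParquetOptions.builder().enableStrictDecimalType(true).build();
      try {
        return Table.readParquet(strict, path);
      } catch (CudfException e) {
        // Decimals stored as FIXED_LEN_BYTE_ARRAY cannot be read strictly yet;
        // retry with the default behavior, which converts them to FLOAT64.
        return Table.readParquet(ParquetOptions.builder().build(), path);
      }
    }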

@Test
void testReadORC() {
ORCOptions opts = ORCOptions.builder()
Binary file added java/src/test/resources/decimal.parquet
Binary file not shown.
