diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b88b350725..fd3a045c9b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -127,6 +127,7 @@ - PR #6837 Avoid gather when copying strings view from start of strings column - PR #6859 Move align_ptr_for_type() from cuda.cuh to alignment.hpp - PR #6807 Refactor `std::array` usage in row group index writing in ORC +- PR #6908 Parquet option for strictly decimal reading ## Bug Fixes diff --git a/java/src/main/java/ai/rapids/cudf/ParquetOptions.java b/java/src/main/java/ai/rapids/cudf/ParquetOptions.java index 4ef1b713531..d8bb6b45f88 100644 --- a/java/src/main/java/ai/rapids/cudf/ParquetOptions.java +++ b/java/src/main/java/ai/rapids/cudf/ParquetOptions.java @@ -27,22 +27,30 @@ public class ParquetOptions extends ColumnFilterOptions { private final DType unit; + private final boolean strictDecimalType; + private ParquetOptions(Builder builder) { super(builder); unit = builder.unit; + strictDecimalType = builder.strictDecimalType; } DType timeUnit() { return unit; } + boolean isStrictDecimalType() { + return strictDecimalType; + } + public static Builder builder() { return new Builder(); } public static class Builder extends ColumnFilterOptions.Builder { private DType unit = DType.EMPTY; + private boolean strictDecimalType = false; /** * Specify the time unit to use when returning timestamps. @@ -55,6 +63,16 @@ public Builder withTimeUnit(DType unit) { return this; } + /** + * Specify how to deal with decimal columns that are not backed by INT32/64 while reading.
+ * @param strictDecimalType whether to strictly read all decimal columns as fixed-point decimal types + * @return builder for chaining + */ + public Builder enableStrictDecimalType(boolean strictDecimalType) { + this.strictDecimalType = strictDecimalType; + return this; + } + public ParquetOptions build() { return new ParquetOptions(this); } diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index b7841c33a79..66552acc5bc 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -219,15 +219,17 @@ private static native long[] readCSV(String[] columnNames, String[] dTypes, /** * Read in Parquet formatted data. - * @param filterColumnNames name of the columns to read, or an empty array if we want to read - * all of them - * @param filePath the path of the file to read, or null if no path should be read. - * @param address the address of the buffer to read from or 0 if we should not. - * @param length the length of the buffer to read from. - * @param timeUnit return type of TimeStamp in units + * @param filterColumnNames name of the columns to read, or an empty array if we want to read + * all of them + * @param filePath the path of the file to read, or null if no path should be read. + * @param address the address of the buffer to read from or 0 if we should not. + * @param length the length of the buffer to read from. + * @param timeUnit return type of TimeStamp in units + * @param strictDecimalTypes whether to strictly read all decimal columns as fixed-point decimal types */ private static native long[] readParquet(String[] filterColumnNames, String filePath, - long address, long length, int timeUnit) throws CudfException; + long address, long length, int timeUnit, + boolean strictDecimalTypes) throws CudfException; /** * Setup everything to write parquet formatted data to a file.
@@ -618,7 +620,8 @@ public static Table readParquet(File path) { */ public static Table readParquet(ParquetOptions opts, File path) { return new Table(readParquet(opts.getIncludeColumnNames(), - path.getAbsolutePath(), 0, 0, opts.timeUnit().typeId.getNativeId())); + path.getAbsolutePath(), 0, 0, opts.timeUnit().typeId.getNativeId(), + opts.isStrictDecimalType())); } /** @@ -678,7 +681,8 @@ public static Table readParquet(ParquetOptions opts, HostMemoryBuffer buffer, assert len <= buffer.getLength() - offset; assert offset >= 0 && offset < buffer.length; return new Table(readParquet(opts.getIncludeColumnNames(), - null, buffer.getAddress() + offset, len, opts.timeUnit().typeId.getNativeId())); + null, buffer.getAddress() + offset, len, opts.timeUnit().typeId.getNativeId(), + opts.isStrictDecimalType())); } /** diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 4dc32307552..3ec1f5e3c94 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -787,7 +787,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet( JNIEnv *env, jclass j_class_object, jobjectArray filter_col_names, jstring inputfilepath, - jlong buffer, jlong buffer_length, jint unit) { + jlong buffer, jlong buffer_length, jint unit, jboolean strict_decimal_types) { bool read_buffer = true; if (buffer == 0) { JNI_NULL_CHECK(env, inputfilepath, "input file or buffer must be supplied", NULL); @@ -823,6 +823,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet( .convert_strings_to_categories(false) .timestamp_type(cudf::data_type(static_cast(unit))) .build(); + opts.set_strict_decimal_types(static_cast(strict_decimal_types)); cudf::io::table_with_metadata result = cudf::io::read_parquet(opts); return cudf::jni::convert_table_for_return(env, result.tbl); } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java 
b/java/src/test/java/ai/rapids/cudf/TableTest.java index edacfc37cc6..c2879d09b54 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -67,6 +67,7 @@ public class TableTest extends CudfTestBase { private static final File TEST_ORC_FILE = new File("src/test/resources/TestOrcFile.orc"); private static final File TEST_ORC_TIMESTAMP_DATE_FILE = new File( "src/test/resources/timestamp-date-test.orc"); + private static final File TEST_DECIMAL_PARQUET_FILE = new File("src/test/resources/decimal.parquet"); private static final Schema CSV_DATA_BUFFER_SCHEMA = Schema.builder() .column(DType.INT32, "A") @@ -717,6 +718,30 @@ void testReadParquetFull() { } } + @Test + void testReadParquetContainsDecimalData() { + try (Table table = Table.readParquet(TEST_DECIMAL_PARQUET_FILE)) { + long rows = table.getRowCount(); + assertEquals(100, rows); + DType[] expectedTypes = new DType[]{ + DType.create(DType.DTypeEnum.DECIMAL64, 0), // Decimal(18, 0) + DType.create(DType.DTypeEnum.DECIMAL32, -3), // Decimal(7, 3) + DType.create(DType.DTypeEnum.DECIMAL64, -10), // Decimal(10, 10) + DType.create(DType.DTypeEnum.DECIMAL32, 0), // Decimal(1, 0) + DType.create(DType.DTypeEnum.DECIMAL64, -15), // Decimal(18, 15) + DType.FLOAT64, // Decimal(20, 10) which is backed by FIXED_LEN_BYTE_ARRAY + DType.INT64, + DType.FLOAT32 + }; + assertTableTypes(expectedTypes, table); + } + // A CudfException will be thrown here because we don't yet support reading decimal stored as FIXED_LEN_BYTE_ARRAY.
+ ParquetOptions opts = ParquetOptions.builder().enableStrictDecimalType(true).build(); + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (Table table = Table.readParquet(opts, TEST_DECIMAL_PARQUET_FILE)) {} + }); + } + @Test void testReadORC() { ORCOptions opts = ORCOptions.builder() diff --git a/java/src/test/resources/decimal.parquet b/java/src/test/resources/decimal.parquet new file mode 100644 index 00000000000..cc0da9ed948 Binary files /dev/null and b/java/src/test/resources/decimal.parquet differ