From 64427a4c5f59d318d4b2702e943a167bb58f9a35 Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Fri, 15 Apr 2022 14:31:44 +0800
Subject: [PATCH] Fix Parquet field IDs being written for columns without one

set_column_metadata assigned a Parquet field ID to every column whenever
the parquetFieldIds array was non-null, so columns that never specified
an ID picked up whatever placeholder value occupied their slot in the
array. Thread a per-column hasParquetFieldIds flag through the JNI
metadata so that only columns which explicitly set a field ID have one
written, and add a test case covering a schema that mixes columns with
and without field IDs both at the top level and inside nested structs.
---
 java/src/main/native/src/TableJni.cpp         | 13 ++--
 .../test/java/ai/rapids/cudf/TableTest.java   | 67 +++++++++++++++++++
 2 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp
index d5cd07c9826..919958d4db2 100644
--- a/java/src/main/native/src/TableJni.cpp
+++ b/java/src/main/native/src/TableJni.cpp
@@ -673,6 +673,7 @@ int set_column_metadata(cudf::io::column_in_metadata &column_metadata,
                         cudf::jni::native_jbooleanArray &is_int96,
                         cudf::jni::native_jintArray &precisions,
                         cudf::jni::native_jbooleanArray &is_map,
+                        cudf::jni::native_jbooleanArray &hasParquetFieldIds,
                         cudf::jni::native_jintArray &parquetFieldIds,
                         cudf::jni::native_jintArray &children, int num_children, int read_index) {
   int write_index = 0;
@@ -688,15 +689,15 @@ int set_column_metadata(cudf::io::column_in_metadata &column_metadata,
     if (is_map[read_index]) {
       child.set_list_column_as_map();
     }
-    if (!parquetFieldIds.is_null()) {
+    if (!parquetFieldIds.is_null() && hasParquetFieldIds[read_index]) {
       child.set_parquet_field_id(parquetFieldIds[read_index]);
     }
     column_metadata.add_child(child);
     int childs_children = children[read_index++];
     if (childs_children > 0) {
       read_index = set_column_metadata(column_metadata.child(write_index), col_names, nullability,
-                                       is_int96, precisions, is_map, parquetFieldIds, children,
-                                       childs_children, read_index);
+                                       is_int96, precisions, is_map, hasParquetFieldIds,
+                                       parquetFieldIds, children, childs_children, read_index);
     }
   }
   return read_index;
@@ -741,9 +742,9 @@ void createTableMetaData(JNIEnv *env, jint num_children, jobjectArray &j_col_nam
     }
     int childs_children = children[read_index++];
     if (childs_children > 0) {
-      read_index = set_column_metadata(metadata.column_metadata[write_index], cpp_names,
-                                       col_nullability, is_int96, precisions, is_map,
-                                       parquetFieldIds, children, childs_children, read_index);
+      read_index = set_column_metadata(
+          metadata.column_metadata[write_index], cpp_names, col_nullability, is_int96, precisions,
+          is_map, hasParquetFieldIds, parquetFieldIds, children, childs_children, read_index);
     }
   }
 }
diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java
index 120e30fbe63..af28cfb6d6c 100644
--- a/java/src/test/java/ai/rapids/cudf/TableTest.java
+++ b/java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -7953,6 +7953,73 @@ void testParquetWriteWithFieldId() throws IOException {
     }
   }
 
+  @Test
+  void testParquetWriteWithFieldIdNestNotSpecified() throws IOException {
+    // field IDs are:
+    // c0: no field ID
+    // c1: 1
+    // c2: no field ID
+    //   c21: 21
+    //   c22: no field ID
+    // c3: 3
+    //   c31: 31
+    //   c32: no field ID
+    // c4: 0
+    ColumnWriterOptions.StructBuilder c2Builder =
+        structBuilder("c2", true)
+            .withColumn(true, "c21", 21)
+            .withColumns(true, "c22");
+    ColumnWriterOptions.StructBuilder c3Builder =
+        structBuilder("c3", true, 3)
+            .withColumn(true, "c31", 31)
+            .withColumns(true, "c32");
+    ParquetWriterOptions options = ParquetWriterOptions.builder()
+        .withColumns(true, "c0")
+        .withDecimalColumn("c1", 9, true, 1)
+        .withStructColumn(c2Builder.build())
+        .withStructColumn(c3Builder.build())
+        .withColumn(true, "c4", 0)
+        .build();
+
+    File tempFile = File.createTempFile("test-field-id", ".parquet");
+    try {
+      HostColumnVector.StructType structType = new HostColumnVector.StructType(
+          true,
+          new HostColumnVector.BasicType(true, DType.STRING),
+          new HostColumnVector.BasicType(true, DType.STRING));
+
+      try (Table table0 = new Table.TestBuilder()
+          .column(true, false) // c0
+          .decimal32Column(0, 298, 2473) // c1
+          .column(structType, // c2
+              new HostColumnVector.StructData("a", "b"), new HostColumnVector.StructData("a", "b"))
+          .column(structType, // c3
+              new HostColumnVector.StructData("a", "b"), new HostColumnVector.StructData("a", "b"))
+          .column("a", "b") // c4
+          .build()) {
+        try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
+          writer.write(table0);
+        }
+      }
+
+      try (ParquetFileReader reader = ParquetFileReader.open(HadoopInputFile.fromPath(
+          new Path(tempFile.getAbsolutePath()),
+          new Configuration()))) {
+        MessageType schema = reader.getFooter().getFileMetaData().getSchema();
+        assert (schema.getFields().get(0).getId() == null);
+        assert (schema.getFields().get(1).getId().intValue() == 1);
+        assert (schema.getFields().get(2).getId() == null);
+        assert (((GroupType) schema.getFields().get(2)).getFields().get(0).getId().intValue() == 21);
+        assert (((GroupType) schema.getFields().get(2)).getFields().get(1).getId() == null);
+        assert (((GroupType) schema.getFields().get(3)).getFields().get(0).getId().intValue() == 31);
+        assert (((GroupType) schema.getFields().get(3)).getFields().get(1).getId() == null);
+        assert (schema.getFields().get(4).getId().intValue() == 0);
+      }
+    } finally {
+      tempFile.delete();
+    }
+  }
+
   /** Return a column where DECIMAL64 has been up-casted to DECIMAL128 */
   private ColumnVector castDecimal64To128(ColumnView c) {
     DType dtype = c.getType();
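
Note: the gating this patch adds is easiest to see against the flattened metadata
that set_column_metadata walks on the JNI side. Below is a minimal, self-contained
sketch of that logic for the test schema above. The class name, array contents, and
the -1 placeholder are illustrative assumptions, not taken from the cudf codebase;
only the depth-first order and the per-column boolean gate mirror the patched code.

public class FieldIdGatingSketch {
  public static void main(String[] args) {
    // Depth-first flattening of the test schema: c0, c1, c2, c21, c22, c3, c31, c32, c4.
    String[]  names = {"c0", "c1", "c2", "c21", "c22", "c3", "c31", "c32", "c4"};
    boolean[] hasId = {false, true, false, true, false, true, true, false, true};
    int[]     ids   = {-1, 1, -1, 21, -1, 3, 31, -1, 0}; // -1 models an unset slot

    for (int i = 0; i < names.length; i++) {
      // Before the fix: a non-null ids array caused set_parquet_field_id(ids[i])
      // to run for every column, so c0, c2, c22, and c32 were written with the
      // placeholder occupying their slots. After the fix, the per-column flag
      // gates the call, matching the assertions in the test above:
      if (hasId[i]) {
        System.out.println(names[i] + " -> field ID " + ids[i]);
      } else {
        System.out.println(names[i] + " -> no field ID written");
      }
    }
  }
}

Since field IDs are optional per column in the Parquet format, a sentinel inside the
int array cannot distinguish "no ID" from a legitimate ID (the test deliberately uses
0 for c4); carrying presence in a parallel boolean array is how the patch encodes that.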