diff --git a/java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java b/java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java index 298e99b059d..ee5ae094b29 100644 --- a/java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java +++ b/java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java @@ -67,6 +67,64 @@ public Builder withCallback(DoneOnGpu callback) { return this; } + /** + * Add the name(s) for nullable column(s). + * + * Please note that the column names of nested struct columns should be flattened in sequence. + * For example, + *
+     *   A table with an int column and a struct column:
+     *                   ["int_col", "struct_col":{"field_1", "field_2"}]
+     *   output:
+     *                   ["int_col", "struct_col", "field_1", "field_2"]
+     *
+     *   A table with an int column and a list of non-nested type column:
+     *                   ["int_col", "list_col":[]]
+     *   output:
+     *                   ["int_col", "list_col"]
+     *
+     *   A table with an int column and a list of struct column:
+     *                   ["int_col", "list_struct_col":[{"field_1", "field_2"}]]
+     *   output:
+     *                   ["int_col", "list_struct_col", "field_1", "field_2"]
+     * 
+ * + * @param columnNames The column names corresponding to the written table(s). + */ + @Override + public Builder withColumnNames(String... columnNames) { + return super.withColumnNames(columnNames); + } + + /** + * Add the name(s) for non-nullable column(s). + * + * Please note that the column names of nested struct columns should be flattened in sequence. + * For example, + *
+     *   A table with an int column and a struct column:
+     *                   ["int_col", "struct_col":{"field_1", "field_2"}]
+     *   output:
+     *                   ["int_col", "struct_col", "field_1", "field_2"]
+     *
+     *   A table with an int column and a list of non-nested type column:
+     *                   ["int_col", "list_col":[]]
+     *   output:
+     *                   ["int_col", "list_col"]
+     *
+     *   A table with an int column and a list of struct column:
+     *                   ["int_col", "list_struct_col":[{"field_1", "field_2"}]]
+     *   output:
+     *                   ["int_col", "list_struct_col", "field_1", "field_2"]
+     * 
+ * + * @param columnNames The column names corresponding to the written table(s). + */ + @Override + public Builder withNotNullableColumnNames(String... columnNames) { + return super.withNotNullableColumnNames(columnNames); + } + public ArrowIPCWriterOptions build() { return new ArrowIPCWriterOptions(this); } diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index e051f68be4e..43616ea413d 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -211,12 +211,15 @@ class native_arrow_ipc_writer_handle final { const std::shared_ptr &sink) : initialized(false), column_names(col_names), file_name(""), sink(sink) {} +private: bool initialized; std::vector column_names; + std::vector columns_meta; std::string file_name; std::shared_ptr sink; std::shared_ptr writer; +public: void write(std::shared_ptr &arrow_tab, int64_t max_chunk) { if (!initialized) { if (!sink) { @@ -245,6 +248,75 @@ class native_arrow_ipc_writer_handle final { } initialized = false; } + + /** + * Return the column metadata for `tview`, building it from `column_names` on first use. + * Throws cudf::jni::jni_exception when too few or too many names are provided. + */ + std::vector get_column_metadata(const cudf::table_view& tview) { + if (!column_names.empty() && columns_meta.empty()) { + // Rebuild the structure of column meta according to table schema. + // All the tables written by this writer should share the same schema, + // so build column metadata only once. + // Build into a local vector and publish it only after validation, so that + // a failure cannot leave partially built metadata cached for later calls. + decltype(columns_meta) built; + built.reserve(tview.num_columns()); + size_t idx = 0; + for (auto itr = tview.begin(); itr < tview.end(); ++itr) { + // It should consume the column names only when a column is + // - type of struct, or + // - not a child. 
+ built.push_back(build_one_column_meta(*itr, idx)); + } + if (idx < column_names.size()) { + throw cudf::jni::jni_exception("Too many column names are provided."); + } + columns_meta.swap(built); + } + return columns_meta; + } + +private: + /** + * Build the metadata of one column and, recursively, of its children. + * `idx` is the position of the next unused entry of `column_names`; it is + * advanced for every name consumed. `consume_name` is false for the child of a list. + */ + cudf::column_metadata build_one_column_meta(const cudf::column_view& cview, size_t& idx, + const bool consume_name = true) { + auto col_meta = cudf::column_metadata{}; + if (consume_name) { + col_meta.name = get_column_name(idx++); + } + // Process children + if (cview.type().id() == cudf::type_id::LIST) { + // list type: + // - requires a stub metadata for offset column(index: 0). + // - does not require a name for the child column(index 1). + col_meta.children_meta = {{}, build_one_column_meta(cview.child(1), idx, false)}; + } else if (cview.type().id() == cudf::type_id::STRUCT) { + // struct type always consumes the column names. + col_meta.children_meta.reserve(cview.num_children()); + for (auto itr = cview.child_begin(); itr < cview.child_end(); ++itr) { + col_meta.children_meta.push_back(build_one_column_meta(*itr, idx)); + } + } else if (cview.type().id() == cudf::type_id::DICTIONARY32) { + // not supported yet in JNI, nested type? 
+ throw cudf::jni::jni_exception("Unsupported type 'DICTIONARY32'"); + } + return col_meta; + } + + /** + * Return the name at position `idx`, or throw if the provided names are exhausted. + */ + std::string& get_column_name(const size_t idx) { + if (idx >= column_names.size()) { + throw cudf::jni::jni_exception("Missing names for columns or nested struct columns"); + } + return column_names[idx]; + } }; class jni_arrow_output_stream final : public arrow::io::OutputStream { @@ -1245,12 +1317,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_convertCudfToArrowTable(JNIEnv cudf::jni::auto_set_device(env); std::unique_ptr> result( new std::shared_ptr(nullptr)); - auto column_metadata = std::vector{}; - column_metadata.reserve(state->column_names.size()); - std::transform(std::begin(state->column_names), std::end(state->column_names), - std::back_inserter(column_metadata), - [](auto const &column_name) { return cudf::column_metadata{column_name}; }); - *result = cudf::to_arrow(*tview, column_metadata); + *result = cudf::to_arrow(*tview, state->get_column_metadata(*tview)); if (!result->get()) { return 0; } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 88196a4112a..625260b255f 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4056,15 +4056,38 @@ void testTableBasedFilter() { } private Table getExpectedFileTable() { - return new TestBuilder() - .column(true, false, false, true, false) - .column(5, 1, 0, 2, 7) - .column(new Byte[]{2, 3, 4, 5, 9}) - .column(3l, 9l, 4l, 2l, 20l) - .column("this", "is", "a", "test", "string") - .column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f) - .column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d) - .build(); + return getExpectedFileTable(false); + } + + private Table getExpectedFileTable(boolean withNestedColumns) { + TestBuilder tb = new TestBuilder() + .column(true, false, false, true, false) + .column(5, 1, 0, 2, 7) + .column(new Byte[]{2, 3, 4, 5, 9}) + .column(3l, 9l, 4l, 2l, 20l) + .column("this", 
"is", "a", "test", "string") + .column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f) + .column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d); + if (withNestedColumns) { + StructType nestedType = new StructType(true, + new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)); + tb.column(nestedType, + struct(1, "k1"), struct(2, "k2"), struct(3, "k3"), + struct(4, "k4"), new HostColumnVector.StructData((List) null)) + .column(new ListType(false, new BasicType(false, DType.INT32)), + Arrays.asList(1, 2), + Arrays.asList(3, 4), + Arrays.asList(5), + Arrays.asList(6, 7), + Arrays.asList(8, 9, 10)) + .column(new ListType(false, nestedType), + Arrays.asList(struct(1, "k1"), struct(2, "k2"), struct(3, "k3")), + Arrays.asList(struct(4, "k4"), struct(5, "k5")), + Arrays.asList(struct(6, "k6")), + Arrays.asList(new HostColumnVector.StructData((List) null)), + Arrays.asList()); + } + return tb.build(); } private Table getExpectedFileTableWithDecimals() { @@ -4272,10 +4295,13 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException { @Test void testArrowIPCWriteToBufferChunked() { - try (Table table0 = getExpectedFileTable(); + try (Table table0 = getExpectedFileTable(true); MyBufferConsumer consumer = new MyBufferConsumer()) { ArrowIPCWriterOptions options = ArrowIPCWriterOptions.builder() .withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh") + .withColumnNames("eighth", "eighth_id", "eighth_name") + .withColumnNames("ninth") + .withColumnNames("tenth", "child_id", "child_name") .build(); try (TableWriter writer = Table.writeArrowIPCChunked(options, consumer)) { writer.write(table0);