diff --git a/java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java b/java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java
index 298e99b059d..ee5ae094b29 100644
--- a/java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java
+++ b/java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java
@@ -67,6 +67,64 @@ public Builder withCallback(DoneOnGpu callback) {
return this;
}
+ /**
+ * Add the name(s) for nullable column(s).
+ *
+ * Please note the column names of the nested struct columns should be flattened in sequence.
+ * For examples,
+ *
+ * A table with an int column and a struct column:
+ * ["int_col", "struct_col":{"field_1", "field_2"}]
+ * output:
+ * ["int_col", "struct_col", "field_1", "field_2"]
+ *
+ * A table with an int column and a list of non-nested type column:
+ * ["int_col", "list_col":[]]
+ * output:
+ * ["int_col", "list_col"]
+ *
+ * A table with an int column and a list of struct column:
+ * ["int_col", "list_struct_col":[{"field_1", "field_2"}]]
+ * output:
+ * ["int_col", "list_struct_col", "field_1", "field_2"]
+ *
+ *
+ * @param columnNames The column names corresponding to the written table(s).
+ */
+ @Override
+ public Builder withColumnNames(String... columnNames) {
+ return super.withColumnNames(columnNames);
+ }
+
+ /**
+ * Add the name(s) for non-nullable column(s).
+ *
+ * Please note the column names of the nested struct columns should be flattened in sequence.
+ * For examples,
+ *
+ * A table with an int column and a struct column:
+ * ["int_col", "struct_col":{"field_1", "field_2"}]
+ * output:
+ * ["int_col", "struct_col", "field_1", "field_2"]
+ *
+ * A table with an int column and a list of non-nested type column:
+ * ["int_col", "list_col":[]]
+ * output:
+ * ["int_col", "list_col"]
+ *
+ * A table with an int column and a list of struct column:
+ * ["int_col", "list_struct_col":[{"field_1", "field_2"}]]
+ * output:
+ * ["int_col", "list_struct_col", "field_1", "field_2"]
+ *
+ *
+ * @param columnNames The column names corresponding to the written table(s).
+ */
+ @Override
+ public Builder withNotNullableColumnNames(String... columnNames) {
+ return super.withNotNullableColumnNames(columnNames);
+ }
+
public ArrowIPCWriterOptions build() {
return new ArrowIPCWriterOptions(this);
}
diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp
index e051f68be4e..43616ea413d 100644
--- a/java/src/main/native/src/TableJni.cpp
+++ b/java/src/main/native/src/TableJni.cpp
@@ -211,12 +211,15 @@ class native_arrow_ipc_writer_handle final {
const std::shared_ptr &sink)
: initialized(false), column_names(col_names), file_name(""), sink(sink) {}
+private:
bool initialized;
std::vector column_names;
+ std::vector columns_meta;
std::string file_name;
std::shared_ptr sink;
std::shared_ptr writer;
+public:
void write(std::shared_ptr &arrow_tab, int64_t max_chunk) {
if (!initialized) {
if (!sink) {
@@ -245,6 +248,59 @@ class native_arrow_ipc_writer_handle final {
}
initialized = false;
}
+
+ std::vector get_column_metadata(const cudf::table_view& tview) {
+ if (!column_names.empty() && columns_meta.empty()) {
+ // Rebuild the structure of column meta according to table schema.
+ // All the tables written by this writer should share the same schema,
+ // so build column metadata only once.
+ columns_meta.reserve(tview.num_columns());
+ size_t idx = 0;
+ for (auto itr = tview.begin(); itr < tview.end(); ++itr) {
+ // It should consume the column names only when a column is
+ // - type of struct, or
+ // - not a child.
+ columns_meta.push_back(build_one_column_meta(*itr, idx));
+ }
+ if (idx < column_names.size()) {
+ throw cudf::jni::jni_exception("Too many column names are provided.");
+ }
+ }
+ return columns_meta;
+ }
+
+private:
+ cudf::column_metadata build_one_column_meta(const cudf::column_view& cview, size_t& idx,
+ const bool consume_name = true) {
+ auto col_meta = cudf::column_metadata{};
+ if (consume_name) {
+ col_meta.name = get_column_name(idx++);
+ }
+ // Process children
+ if (cview.type().id() == cudf::type_id::LIST) {
+ // list type:
+ // - requires a stub metadata for offset column(index: 0).
+ // - does not require a name for the child column(index 1).
+ col_meta.children_meta = {{}, build_one_column_meta(cview.child(1), idx, false)};
+ } else if (cview.type().id() == cudf::type_id::STRUCT) {
+ // struct type always consumes the column names.
+ col_meta.children_meta.reserve(cview.num_children());
+ for (auto itr = cview.child_begin(); itr < cview.child_end(); ++itr) {
+ col_meta.children_meta.push_back(build_one_column_meta(*itr, idx));
+ }
+ } else if (cview.type().id() == cudf::type_id::DICTIONARY32) {
+ // not supported yet in JNI, nested type?
+ throw cudf::jni::jni_exception("Unsupported type 'DICTIONARY32'");
+ }
+ return col_meta;
+ }
+
+ std::string& get_column_name(const size_t idx) {
+ if (idx < 0 || idx >= column_names.size()) {
+ throw cudf::jni::jni_exception("Missing names for columns or nested struct columns");
+ }
+ return column_names[idx];
+ }
};
class jni_arrow_output_stream final : public arrow::io::OutputStream {
@@ -1245,12 +1301,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_convertCudfToArrowTable(JNIEnv
cudf::jni::auto_set_device(env);
std::unique_ptr> result(
new std::shared_ptr(nullptr));
- auto column_metadata = std::vector{};
- column_metadata.reserve(state->column_names.size());
- std::transform(std::begin(state->column_names), std::end(state->column_names),
- std::back_inserter(column_metadata),
- [](auto const &column_name) { return cudf::column_metadata{column_name}; });
- *result = cudf::to_arrow(*tview, column_metadata);
+ *result = cudf::to_arrow(*tview, state->get_column_metadata(*tview));
if (!result->get()) {
return 0;
}
diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java
index 88196a4112a..625260b255f 100644
--- a/java/src/test/java/ai/rapids/cudf/TableTest.java
+++ b/java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -4056,15 +4056,38 @@ void testTableBasedFilter() {
}
private Table getExpectedFileTable() {
- return new TestBuilder()
- .column(true, false, false, true, false)
- .column(5, 1, 0, 2, 7)
- .column(new Byte[]{2, 3, 4, 5, 9})
- .column(3l, 9l, 4l, 2l, 20l)
- .column("this", "is", "a", "test", "string")
- .column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f)
- .column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d)
- .build();
+ return getExpectedFileTable(false);
+ }
+
+ private Table getExpectedFileTable(boolean withNestedColumns) {
+ TestBuilder tb = new TestBuilder()
+ .column(true, false, false, true, false)
+ .column(5, 1, 0, 2, 7)
+ .column(new Byte[]{2, 3, 4, 5, 9})
+ .column(3l, 9l, 4l, 2l, 20l)
+ .column("this", "is", "a", "test", "string")
+ .column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f)
+ .column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d);
+ if (withNestedColumns) {
+ StructType nestedType = new StructType(true,
+ new BasicType(false, DType.INT32), new BasicType(false, DType.STRING));
+ tb.column(nestedType,
+ struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
+ struct(4, "k4"), new HostColumnVector.StructData((List) null))
+ .column(new ListType(false, new BasicType(false, DType.INT32)),
+ Arrays.asList(1, 2),
+ Arrays.asList(3, 4),
+ Arrays.asList(5),
+ Arrays.asList(6, 7),
+ Arrays.asList(8, 9, 10))
+ .column(new ListType(false, nestedType),
+ Arrays.asList(struct(1, "k1"), struct(2, "k2"), struct(3, "k3")),
+ Arrays.asList(struct(4, "k4"), struct(5, "k5")),
+ Arrays.asList(struct(6, "k6")),
+ Arrays.asList(new HostColumnVector.StructData((List) null)),
+ Arrays.asList());
+ }
+ return tb.build();
}
private Table getExpectedFileTableWithDecimals() {
@@ -4272,10 +4295,13 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException {
@Test
void testArrowIPCWriteToBufferChunked() {
- try (Table table0 = getExpectedFileTable();
+ try (Table table0 = getExpectedFileTable(true);
MyBufferConsumer consumer = new MyBufferConsumer()) {
ArrowIPCWriterOptions options = ArrowIPCWriterOptions.builder()
.withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh")
+ .withColumnNames("eighth", "eighth_id", "eighth_name")
+ .withColumnNames("ninth")
+ .withColumnNames("tenth", "child_id", "child_name")
.build();
try (TableWriter writer = Table.writeArrowIPCChunked(options, consumer)) {
writer.write(table0);