Skip to content

Commit

Permalink
JNI: Pass names of children struct columns to native Arrow IPC writer (
Browse files Browse the repository at this point in the history
…#7598)

This PR adds support for building the nested structure of the column metadata from the flattened column names according to the table schema,
since the children column metadata is required when converting cudf tables to Arrow tables.

It also updates the related unit tests.

closes #7570

Signed-off-by: Firestarman <[email protected]>

Authors:
  - Liangcai Li (@firestarman)

Approvers:
  - Jason Lowe (@jlowe)

URL: #7598
  • Loading branch information
firestarman authored Mar 20, 2021
1 parent 217d702 commit cdd44d2
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 16 deletions.
58 changes: 58 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,64 @@ public Builder withCallback(DoneOnGpu callback) {
return this;
}

/**
 * Add the name(s) for nullable column(s).
 *
 * Please note that the names of nested struct columns must be flattened in sequence:
 * a struct column contributes its own name followed by the names of its fields,
 * while a list column contributes only its own name (the list child needs a name
 * only when it is a struct, in which case the names of the struct fields follow).
 * For example,
 * <pre>
 *   A table with an int column and a struct column:
 *                   ["int_col", "struct_col":{"field_1", "field_2"}]
 *   output:
 *                   ["int_col", "struct_col", "field_1", "field_2"]
 *
 *   A table with an int column and a list of non-nested type column:
 *                   ["int_col", "list_col":[]]
 *   output:
 *                   ["int_col", "list_col"]
 *
 *   A table with an int column and a list of struct column:
 *                   ["int_col", "list_struct_col":[{"field_1", "field_2"}]]
 *   output:
 *                   ["int_col", "list_struct_col", "field_1", "field_2"]
 * </pre>
 *
 * @param columnNames The column names corresponding to the written table(s).
 * @return the builder returned by the parent implementation, for call chaining.
 */
@Override
public Builder withColumnNames(String... columnNames) {
return super.withColumnNames(columnNames);
}

/**
 * Add the name(s) for non-nullable column(s).
 *
 * Please note that the names of nested struct columns must be flattened in sequence:
 * a struct column contributes its own name followed by the names of its fields,
 * while a list column contributes only its own name (the list child needs a name
 * only when it is a struct, in which case the names of the struct fields follow).
 * For example,
 * <pre>
 *   A table with an int column and a struct column:
 *                   ["int_col", "struct_col":{"field_1", "field_2"}]
 *   output:
 *                   ["int_col", "struct_col", "field_1", "field_2"]
 *
 *   A table with an int column and a list of non-nested type column:
 *                   ["int_col", "list_col":[]]
 *   output:
 *                   ["int_col", "list_col"]
 *
 *   A table with an int column and a list of struct column:
 *                   ["int_col", "list_struct_col":[{"field_1", "field_2"}]]
 *   output:
 *                   ["int_col", "list_struct_col", "field_1", "field_2"]
 * </pre>
 *
 * @param columnNames The column names corresponding to the written table(s).
 * @return the builder returned by the parent implementation, for call chaining.
 */
@Override
public Builder withNotNullableColumnNames(String... columnNames) {
return super.withNotNullableColumnNames(columnNames);
}

/**
 * Create the writer options from the current state of this builder.
 *
 * @return a new {@link ArrowIPCWriterOptions} constructed from this builder.
 */
public ArrowIPCWriterOptions build() {
return new ArrowIPCWriterOptions(this);
}
Expand Down
63 changes: 57 additions & 6 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,15 @@ class native_arrow_ipc_writer_handle final {
const std::shared_ptr<arrow::io::OutputStream> &sink)
: initialized(false), column_names(col_names), file_name(""), sink(sink) {}

private:
// Checked by write() before it does its one-time setup; reset to false elsewhere in this class.
bool initialized;
// Flattened column names supplied at construction; consumed in order when the
// column metadata is built (see get_column_metadata()).
std::vector<std::string> column_names;
// Column metadata built from column_names by get_column_metadata() and reused
// by later calls once it is non-empty.
std::vector<cudf::column_metadata> columns_meta;
// NOTE(review): presumably the output path when writing to a file; the
// sink-based constructor sets it to "" -- confirm against the other constructor.
std::string file_name;
// Arrow output stream; write() checks whether it is set.
std::shared_ptr<arrow::io::OutputStream> sink;
// Arrow IPC record batch writer.
// NOTE(review): presumably created during the first write() -- confirm.
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;

public:
void write(std::shared_ptr<arrow::Table> &arrow_tab, int64_t max_chunk) {
if (!initialized) {
if (!sink) {
Expand Down Expand Up @@ -246,6 +249,59 @@ class native_arrow_ipc_writer_handle final {
}
initialized = false;
}

/**
 * @brief Returns the column metadata (names, including nested children) for `tview`.
 *
 * The metadata is rebuilt from the flattened `column_names` according to the
 * schema of `tview` the first time it is needed and cached in `columns_meta`.
 * All the tables written by this writer are expected to share the same schema,
 * so later calls return the cached result. When no column names were provided
 * the returned vector is empty.
 *
 * The cache is only updated after the whole table has been processed and the
 * names have been validated, so a failed call never leaves a partially built
 * (or invalid) cache behind for a later call to pick up.
 *
 * @param tview The table whose schema drives the structure of the metadata.
 * @return A copy of the column metadata, one entry per top-level column.
 * @throws cudf::jni::jni_exception if fewer or more names were provided than
 *         the schema of `tview` requires.
 */
std::vector<cudf::column_metadata> get_column_metadata(const cudf::table_view& tview) {
  if (!column_names.empty() && columns_meta.empty()) {
    // Build into a local vector first; publish it only once everything succeeded.
    std::vector<cudf::column_metadata> built;
    built.reserve(tview.num_columns());
    size_t idx = 0;
    for (auto const &col : tview) {
      // A column consumes a name only when it is
      //   - type of struct, or
      //   - not a child.
      built.push_back(build_one_column_meta(col, idx));
    }
    if (idx < column_names.size()) {
      throw cudf::jni::jni_exception("Too many column names are provided.");
    }
    columns_meta.swap(built);
  }
  return columns_meta;
}

private:
/**
 * @brief Builds the metadata of one column, recursing into its children.
 *
 * @param cview The column to describe.
 * @param idx In/out position of the next unused entry in the flattened
 *        `column_names`; advanced for every name that is consumed.
 * @param consume_name Whether this column takes a name from `column_names`.
 *        It is false only for the data child of a list column.
 * @return The metadata of `cview`, with `children_meta` filled for list and
 *         struct columns.
 * @throws cudf::jni::jni_exception if the names run out, or if the column is
 *         of type DICTIONARY32.
 */
cudf::column_metadata build_one_column_meta(const cudf::column_view& cview, size_t& idx,
                                            const bool consume_name = true) {
  cudf::column_metadata meta{};
  if (consume_name) {
    meta.name = get_column_name(idx++);
  }

  switch (cview.type().id()) {
    case cudf::type_id::LIST:
      // A list needs a stub entry for the offsets child (index 0), and its
      // data child (index 1) is described without consuming a name of its own.
      meta.children_meta.push_back(cudf::column_metadata{});
      meta.children_meta.push_back(build_one_column_meta(cview.child(1), idx, false));
      break;
    case cudf::type_id::STRUCT:
      // Every field of a struct always consumes a column name.
      meta.children_meta.reserve(cview.num_children());
      for (cudf::size_type i = 0; i < cview.num_children(); ++i) {
        meta.children_meta.push_back(build_one_column_meta(cview.child(i), idx));
      }
      break;
    case cudf::type_id::DICTIONARY32:
      // not supported yet in JNI, nested type?
      throw cudf::jni::jni_exception("Unsupported type 'DICTIONARY32'");
    default:
      // Any other type has no children that need describing.
      break;
  }
  return meta;
}

/**
 * @brief Returns the flattened column name stored at position `idx`.
 *
 * @param idx Position in `column_names`.
 * @return A reference to the name at `idx`.
 * @throws cudf::jni::jni_exception if `idx` is past the end of `column_names`,
 *         i.e. fewer names were provided than the table schema requires.
 */
std::string& get_column_name(const size_t idx) {
  // idx is unsigned, so only the upper bound can be violated.
  if (idx >= column_names.size()) {
    throw cudf::jni::jni_exception("Missing names for columns or nested struct columns");
  }
  return column_names[idx];
}
};

class jni_arrow_output_stream final : public arrow::io::OutputStream {
Expand Down Expand Up @@ -1262,12 +1318,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_convertCudfToArrowTable(JNIEnv
cudf::jni::auto_set_device(env);
std::unique_ptr<std::shared_ptr<arrow::Table>> result(
new std::shared_ptr<arrow::Table>(nullptr));
auto column_metadata = std::vector<cudf::column_metadata>{};
column_metadata.reserve(state->column_names.size());
std::transform(std::begin(state->column_names), std::end(state->column_names),
std::back_inserter(column_metadata),
[](auto const &column_name) { return cudf::column_metadata{column_name}; });
*result = cudf::to_arrow(*tview, column_metadata);
*result = cudf::to_arrow(*tview, state->get_column_metadata(*tview));
if (!result->get()) {
return 0;
}
Expand Down
46 changes: 36 additions & 10 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4123,15 +4123,38 @@ void testTableBasedFilter() {
}

private Table getExpectedFileTable() {
return new TestBuilder()
.column(true, false, false, true, false)
.column(5, 1, 0, 2, 7)
.column(new Byte[]{2, 3, 4, 5, 9})
.column(3l, 9l, 4l, 2l, 20l)
.column("this", "is", "a", "test", "string")
.column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f)
.column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d)
.build();
return getExpectedFileTable(false);
}

/**
 * Build the expected table: seven flat columns (boolean, int, byte, long, string,
 * float, double) of five rows each, optionally followed by three nested columns
 * (a struct, a list of ints and a list of structs).
 *
 * @param withNestedColumns whether to append the three nested columns.
 * @return the table built from the columns above.
 */
private Table getExpectedFileTable(boolean withNestedColumns) {
  TestBuilder builder = new TestBuilder()
      .column(true, false, false, true, false)
      .column(5, 1, 0, 2, 7)
      .column(new Byte[]{2, 3, 4, 5, 9})
      .column(3L, 9L, 4L, 2L, 20L)
      .column("this", "is", "a", "test", "string")
      .column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f)
      .column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d);
  if (!withNestedColumns) {
    return builder.build();
  }

  // Nullable struct of (non-null int, non-null string), shared by the struct
  // column and the list-of-struct column.
  StructType structType = new StructType(true,
      new BasicType(false, DType.INT32), new BasicType(false, DType.STRING));
  ListType intListType = new ListType(false, new BasicType(false, DType.INT32));
  ListType structListType = new ListType(false, structType);

  builder
      .column(structType,
          struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
          struct(4, "k4"), new HostColumnVector.StructData((List) null))
      .column(intListType,
          Arrays.asList(1, 2),
          Arrays.asList(3, 4),
          Arrays.asList(5),
          Arrays.asList(6, 7),
          Arrays.asList(8, 9, 10))
      .column(structListType,
          Arrays.asList(struct(1, "k1"), struct(2, "k2"), struct(3, "k3")),
          Arrays.asList(struct(4, "k4"), struct(5, "k5")),
          Arrays.asList(struct(6, "k6")),
          Arrays.asList(new HostColumnVector.StructData((List) null)),
          Arrays.asList());
  return builder.build();
}

private Table getExpectedFileTableWithDecimals() {
Expand Down Expand Up @@ -4332,10 +4355,13 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException {

@Test
void testArrowIPCWriteToBufferChunked() {
try (Table table0 = getExpectedFileTable();
try (Table table0 = getExpectedFileTable(true);
MyBufferConsumer consumer = new MyBufferConsumer()) {
ArrowIPCWriterOptions options = ArrowIPCWriterOptions.builder()
.withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh")
.withColumnNames("eighth", "eighth_id", "eighth_name")
.withColumnNames("ninth")
.withColumnNames("tenth", "child_id", "child_name")
.build();
try (TableWriter writer = Table.writeArrowIPCChunked(options, consumer)) {
writer.write(table0);
Expand Down

0 comments on commit cdd44d2

Please sign in to comment.