Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI: Pass names of children struct columns to native Arrow IPC writer [skip ci] #7598

Merged
merged 16 commits into from
Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,15 @@ class native_arrow_ipc_writer_handle final {
const std::shared_ptr<arrow::io::OutputStream> &sink)
: initialized(false), column_names(col_names), file_name(""), sink(sink) {}

private:
// Checked by write() before doing its first-use setup; the constructor
// visible here starts it as false.
bool initialized;
// Column names copied from the constructor's `col_names`; looked up by
// position in get_column_name().
std::vector<std::string> column_names;
// Column metadata built by get_column_metadata() and kept for later calls.
std::vector<cudf::column_metadata> columns_meta;
// Set to "" by the sink-based constructor visible here.
// NOTE(review): presumably a file-based constructor fills this in — confirm.
std::string file_name;
// Arrow output stream passed to the constructor.
std::shared_ptr<arrow::io::OutputStream> sink;
// Arrow IPC record batch writer; where it is created is not shown here.
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;

public:
void write(std::shared_ptr<arrow::Table> &arrow_tab, int64_t max_chunk) {
if (!initialized) {
if (!sink) {
Expand Down Expand Up @@ -245,6 +248,49 @@ class native_arrow_ipc_writer_handle final {
}
initialized = false;
}

std::vector<cudf::column_metadata> get_column_metadata(const table_view& tview) {
if (!column_names.empty() && columns_meta.empty()) {
// Rebuild the structure of column meta according to table schema.
// All the tables written by this writer should share the same schema,
// so build column metadata only once.
columns_meta.reserve(tview.num_columns());
size_t idx = 0;
for (auto itr = tview.begin(); itr < tview.end(); ++itr) {
columns_meta.push_back(std::move(build_one_column_meta(*itr, idx)));
jlowe marked this conversation as resolved.
Show resolved Hide resolved
idx ++;
}
}
jlowe marked this conversation as resolved.
Show resolved Hide resolved
return columns_meta;
}

private:
// Still return an object instead of filling an out argument, even though
// `column_metadata` has no move constructor and would be copied.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
// Builds the metadata for one column, recursing into its children.
//
// `cview` is the column to describe.
// `idx` is an in/out position into the supplied column names: on entry it
// selects the name for `cview`; it is pre-incremented once for every child
// visited, so on return it refers to the last name consumed by this column
// and its descendants.
//
// Propagates the cudf::jni::jni_exception raised by get_column_name() when
// `idx` runs past the supplied names.
cudf::column_metadata build_one_column_meta(const column_view& cview, size_t& idx) {
  auto col_meta = cudf::column_metadata{get_column_name(idx)};
  const auto col_type = cview.type().id();
  if (col_type == cudf::type_id::LIST) {
    // list type requires a stub metadata for offset column, index is 0.
    // The element column is child(1) and takes the next name.
    col_meta.children_meta = {{}, build_one_column_meta(cview.child(1), ++idx)};
  } else if (col_type == cudf::type_id::STRUCT) {
    // struct type: every child consumes the next name, in order.
    col_meta.children_meta.reserve(cview.num_children());
    for (auto itr = cview.child_begin(); itr != cview.child_end(); ++itr) {
      col_meta.children_meta.push_back(build_one_column_meta(*itr, ++idx));
    }
  }
  // DICTIONARY32 is not supported yet in JNI (nested type?), so it gets no
  // children metadata here, same as every other type.
  return col_meta;
}

// Returns the supplied column name at position `idx`.
//
// Throws cudf::jni::jni_exception when `idx` is past the end of the supplied
// names, i.e. fewer names were given than the columns being described need.
std::string& get_column_name(const size_t idx) {
  // `idx` is unsigned, so only the upper bound can be violated.
  if (idx >= column_names.size()) {
    throw cudf::jni::jni_exception(
        "Missing column names for struct columns or nested struct columns");
  }
  return column_names[idx];
}
};

class jni_arrow_output_stream final : public arrow::io::OutputStream {
Expand Down Expand Up @@ -1245,12 +1291,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_convertCudfToArrowTable(JNIEnv
cudf::jni::auto_set_device(env);
std::unique_ptr<std::shared_ptr<arrow::Table>> result(
new std::shared_ptr<arrow::Table>(nullptr));
auto column_metadata = std::vector<cudf::column_metadata>{};
column_metadata.reserve(state->column_names.size());
std::transform(std::begin(state->column_names), std::end(state->column_names),
std::back_inserter(column_metadata),
[](auto const &column_name) { return cudf::column_metadata{column_name}; });
*result = cudf::to_arrow(*tview, column_metadata);
*result = cudf::to_arrow(*tview, state->get_column_metadata(*tview));
if (!result->get()) {
return 0;
}
Expand Down
46 changes: 36 additions & 10 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4056,15 +4056,38 @@ void testTableBasedFilter() {
}

/**
 * Builds the expected table without the nested (struct / list) columns.
 *
 * Delegates to {@link #getExpectedFileTable(boolean)} so the seven flat
 * columns are defined in one place only.
 */
private Table getExpectedFileTable() {
  return getExpectedFileTable(false);
}

/**
 * Builds the expected table: seven flat columns (boolean, int, byte, long,
 * string, float, double), optionally followed by three nested columns.
 *
 * @param withNestedColumns when true, appends a struct column, a list-of-int
 *                          column and a list-of-struct column
 * @return the table produced by the builder
 */
private Table getExpectedFileTable(boolean withNestedColumns) {
  TestBuilder builder = new TestBuilder()
      .column(true, false, false, true, false)
      .column(5, 1, 0, 2, 7)
      .column(new Byte[]{2, 3, 4, 5, 9})
      .column(3L, 9L, 4L, 2L, 20L)
      .column("this", "is", "a", "test", "string")
      .column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f)
      .column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d);
  if (!withNestedColumns) {
    return builder.build();
  }

  // Struct of (INT32, STRING); the struct itself is nullable, its fields are not.
  StructType idNameStruct = new StructType(true,
      new BasicType(false, DType.INT32), new BasicType(false, DType.STRING));
  ListType intList = new ListType(false, new BasicType(false, DType.INT32));
  ListType structList = new ListType(false, idNameStruct);

  builder
      // Struct column; the last row is a null struct.
      .column(idNameStruct,
          struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
          struct(4, "k4"), new HostColumnVector.StructData((List) null))
      // List-of-int column.
      .column(intList,
          Arrays.asList(1, 2),
          Arrays.asList(3, 4),
          Arrays.asList(5),
          Arrays.asList(6, 7),
          Arrays.asList(8, 9, 10))
      // List-of-struct column, including a list holding a null struct and an empty list.
      .column(structList,
          Arrays.asList(struct(1, "k1"), struct(2, "k2"), struct(3, "k3")),
          Arrays.asList(struct(4, "k4"), struct(5, "k5")),
          Arrays.asList(struct(6, "k6")),
          Arrays.asList(new HostColumnVector.StructData((List) null)),
          Arrays.asList());
  return builder.build();
}

private Table getExpectedFileTableWithDecimals() {
Expand Down Expand Up @@ -4272,10 +4295,13 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException {

@Test
void testArrowIPCWriteToBufferChunked() {
try (Table table0 = getExpectedFileTable();
try (Table table0 = getExpectedFileTable(true);
MyBufferConsumer consumer = new MyBufferConsumer()) {
ArrowIPCWriterOptions options = ArrowIPCWriterOptions.builder()
.withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh")
.withColumnNames("eighth", "eighth_id", "eighth_name")
.withColumnNames("ninth", "ninth_child")
.withColumnNames("tenth", "tenth_child", "child_id", "child_name")
.build();
try (TableWriter writer = Table.writeArrowIPCChunked(options, consumer)) {
writer.write(table0);
Expand Down