diff --git a/CHANGELOG.md b/CHANGELOG.md index 799bd94e62a..92ebde6e01a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,7 @@ - PR #5559 Java APIs for missing date/time operators - PR #5582 Add support for axis and other parameters to `DataFrame.sort_index` and fix other bunch of issues. - PR #5562 Add missing join type for java +- PR #5584 Refactor `CompactProtocolReader::InitSchema` - PR #5591 Add `__arrow_array__` protocol and raise a descriptive error message ## Bug Fixes diff --git a/cpp/src/io/parquet/parquet.cpp b/cpp/src/io/parquet/parquet.cpp index e4aa02b4401..2c15972af10 100644 --- a/cpp/src/io/parquet/parquet.cpp +++ b/cpp/src/io/parquet/parquet.cpp @@ -16,6 +16,8 @@ #include "parquet.h" +#include + namespace cudf { namespace io { namespace parquet { @@ -287,34 +289,31 @@ PARQUET_END_STRUCT() **/ bool CompactProtocolReader::InitSchema(FileMetaData *md) { - int final_pos = WalkSchema(md->schema); - if (final_pos != md->schema.size()) { return false; } - - // Map columns to schema - for (size_t i = 0; i < md->row_groups.size(); i++) { - RowGroup *g = &md->row_groups[i]; - int cur = 0; - for (size_t j = 0; j < g->columns.size(); j++) { - ColumnChunk *col = &g->columns[j]; - int parent = 0; // root of schema - for (size_t k = 0; k < col->meta_data.path_in_schema.size(); k++) { - bool found = false; - int pos = cur + 1, maxpos = (int)md->schema.size(); - for (int l = maxpos; l > 0; --l) { - if (pos >= maxpos) { - pos = 0; // wrap around - } - if (md->schema[pos].parent_idx == parent && - md->schema[pos].name == col->meta_data.path_in_schema[k]) { - cur = pos; - found = true; - break; - } - pos++; - } - if (!found) { return false; } - col->schema_idx = cur; - parent = cur; + if (WalkSchema(md->schema) != md->schema.size()) return false; + + /* Inside FileMetaData, there is a std::vector of RowGroups and each RowGroup contains a + * a std::vector of ColumnChunks. Each ColumnChunk has a member ColumnMetaData, which contains + * a std::vector of std::strings representing paths. The purpose of the code below is to set the + * schema_idx of each column of each row to it corresonding row_group. This is effectively + * mapping the columns to the schema. + */ + for (auto &row_group : md->row_groups) { + int current_row_group = 0; + for (auto &column : row_group.columns) { + int parent = 0; // root of schema + for (auto const &path : column.meta_data.path_in_schema) { + auto const it = [&] { + // find_if starting at (current_row_group + 1) and then wrapping + auto schema = [&](auto const &e) { return e.parent_idx == parent && e.name == path; }; + auto mid = md->schema.cbegin() + current_row_group + 1; + auto it = std::find_if(mid, md->schema.cend(), schema); + if (it != md->schema.cend()) return it; + return std::find_if(md->schema.cbegin(), mid, schema); + }(); + if (it == md->schema.cend()) return false; + current_row_group = std::distance(md->schema.cbegin(), it); + column.schema_idx = current_row_group; + parent = current_row_group; } } }