From 40d4cc5565f600864c3b16f30d3d26fd4904deaf Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Wed, 20 Sep 2023 11:03:44 -0700 Subject: [PATCH] Refactor parquet thrift reader (#14097) Refactors the current `CompactProtocolReader` used to parse parquet file metadata. The main goal of the refactor is to allow easier use of `std::optional` fields in the thrift structs to prevent situations as in #14024 where an optional field is an empty string. The writer cannot distinguish between present-but-empty and not-present, so chooses the latter when writing the field. This PR adds a `ParquetFieldOptional` functor that can wrap the other field functors, obviating the need to write a new optional functor for each type. Authors: - Ed Seidl (https://github.com/etseidl) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Yunsong Wang (https://github.com/PointKernel) URL: https://github.com/rapidsai/cudf/pull/14097 --- .../io/parquet/compact_protocol_reader.cpp | 691 +++++++++++++++--- .../io/parquet/compact_protocol_reader.hpp | 586 +-------------- .../io/parquet/compact_protocol_writer.cpp | 30 +- .../io/parquet/compact_protocol_writer.hpp | 3 + cpp/src/io/parquet/parquet.hpp | 18 +- cpp/src/io/parquet/parquet_common.hpp | 2 +- cpp/src/io/parquet/writer_impl.cu | 38 +- 7 files changed, 662 insertions(+), 706 deletions(-) diff --git a/cpp/src/io/parquet/compact_protocol_reader.cpp b/cpp/src/io/parquet/compact_protocol_reader.cpp index ae11af92f78..5c7b8ca3f8c 100644 --- a/cpp/src/io/parquet/compact_protocol_reader.cpp +++ b/cpp/src/io/parquet/compact_protocol_reader.cpp @@ -18,27 +18,474 @@ #include #include +#include #include namespace cudf { namespace io { namespace parquet { -uint8_t const CompactProtocolReader::g_list2struct[16] = {0, - 1, - 2, - ST_FLD_BYTE, - ST_FLD_DOUBLE, - 5, - ST_FLD_I16, - 7, - ST_FLD_I32, - 9, - ST_FLD_I64, - ST_FLD_BINARY, - ST_FLD_STRUCT, - ST_FLD_MAP, - ST_FLD_SET, - ST_FLD_LIST}; + +/** + * @brief Base class for parquet field functors. + * + * Holds the field value used by all of the specialized functors. + */ +class parquet_field { + private: + int _field_val; + + protected: + parquet_field(int f) : _field_val(f) {} + + public: + virtual ~parquet_field() = default; + int field() const { return _field_val; } +}; + +/** + * @brief Abstract base class for list functors. + */ +template +class parquet_field_list : public parquet_field { + private: + using read_func_type = std::function; + FieldType _expected_type; + read_func_type _read_value; + + protected: + std::vector& val; + + void bind_read_func(read_func_type fn) { _read_value = fn; } + + parquet_field_list(int f, std::vector& v, FieldType t) + : parquet_field(f), _expected_type(t), val(v) + { + } + + public: + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if (field_type != ST_FLD_LIST) { return true; } + auto const [t, n] = cpr->get_listh(); + if (t != _expected_type) { return true; } + val.resize(n); + for (uint32_t i = 0; i < n; i++) { + if (_read_value(i, cpr)) { return true; } + } + return false; + } +}; + +/** + * @brief Functor to set value to bool read from CompactProtocolReader + * + * bool doesn't actually encode a value, we just use the field type to indicate true/false + * + * @return True if field type is not bool + */ +class parquet_field_bool : public parquet_field { + bool& val; + + public: + parquet_field_bool(int f, bool& v) : parquet_field(f), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if (field_type != ST_FLD_TRUE && field_type != ST_FLD_FALSE) { return true; } + val = field_type == ST_FLD_TRUE; + return false; + } +}; + +/** + * @brief Functor to read a vector of booleans from CompactProtocolReader + * + * @return True if field types mismatch or if the process of reading a + * bool fails + */ +struct parquet_field_bool_list : public parquet_field_list { + parquet_field_bool_list(int f, std::vector& v) : parquet_field_list(f, v, ST_FLD_TRUE) + { + auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) { + auto const current_byte = cpr->getb(); + if (current_byte != ST_FLD_TRUE && current_byte != ST_FLD_FALSE) { return true; } + this->val[i] = current_byte == ST_FLD_TRUE; + return false; + }; + bind_read_func(read_value); + } +}; + +/** + * @brief Base type for a functor that reads an integer from CompactProtocolReader + * + * Assuming signed ints since the parquet spec does not use unsigned ints anywhere. + * + * @return True if there is a type mismatch + */ +template +class parquet_field_int : public parquet_field { + static constexpr bool is_byte = std::is_same_v; + + T& val; + + public: + parquet_field_int(int f, T& v) : parquet_field(f), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if constexpr (is_byte) { + val = cpr->getb(); + } else { + val = cpr->get_zigzag(); + } + return (field_type != EXPECTED_TYPE); + } +}; + +using parquet_field_int8 = parquet_field_int; +using parquet_field_int32 = parquet_field_int; +using parquet_field_int64 = parquet_field_int; + +/** + * @brief Functor to read a vector of integers from CompactProtocolReader + * + * @return True if field types mismatch or if the process of reading an + * integer fails + */ +template +struct parquet_field_int_list : public parquet_field_list { + parquet_field_int_list(int f, std::vector& v) : parquet_field_list(f, v, EXPECTED_TYPE) + { + auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) { + this->val[i] = cpr->get_zigzag(); + return false; + }; + this->bind_read_func(read_value); + } +}; + +using parquet_field_int64_list = parquet_field_int_list; + +/** + * @brief Functor to read a string from CompactProtocolReader + * + * @return True if field type mismatches or if size of string exceeds bounds + * of the CompactProtocolReader + */ +class parquet_field_string : public parquet_field { + std::string& val; + + public: + parquet_field_string(int f, std::string& v) : parquet_field(f), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if (field_type != ST_FLD_BINARY) { return true; } + auto const n = cpr->get_u32(); + if (n < static_cast(cpr->m_end - cpr->m_cur)) { + val.assign(reinterpret_cast(cpr->m_cur), n); + cpr->m_cur += n; + return false; + } else { + return true; + } + } +}; + +/** + * @brief Functor to read a vector of strings from CompactProtocolReader + * + * @return True if field types mismatch or if the process of reading a + * string fails + */ +struct parquet_field_string_list : public parquet_field_list { + parquet_field_string_list(int f, std::vector& v) + : parquet_field_list(f, v, ST_FLD_BINARY) + { + auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) { + auto const l = cpr->get_u32(); + if (l < static_cast(cpr->m_end - cpr->m_cur)) { + this->val[i].assign(reinterpret_cast(cpr->m_cur), l); + cpr->m_cur += l; + } else { + return true; + } + return false; + }; + bind_read_func(read_value); + } +}; + +/** + * @brief Functor to set value to enum read from CompactProtocolReader + * + * @return True if field type is not int32 + */ +template +class parquet_field_enum : public parquet_field { + Enum& val; + + public: + parquet_field_enum(int f, Enum& v) : parquet_field(f), val(v) {} + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + val = static_cast(cpr->get_i32()); + return (field_type != ST_FLD_I32); + } +}; + +/** + * @brief Functor to read a vector of enums from CompactProtocolReader + * + * @return True if field types mismatch or if the process of reading an + * enum fails + */ +template +struct parquet_field_enum_list : public parquet_field_list { + parquet_field_enum_list(int f, std::vector& v) : parquet_field_list(f, v, ST_FLD_I32) + { + auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) { + this->val[i] = static_cast(cpr->get_i32()); + return false; + }; + this->bind_read_func(read_value); + } +}; + +/** + * @brief Functor to read a structure from CompactProtocolReader + * + * @return True if field types mismatch or if the process of reading a + * struct fails + */ +template +class parquet_field_struct : public parquet_field { + T& val; + + public: + parquet_field_struct(int f, T& v) : parquet_field(f), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + return (field_type != ST_FLD_STRUCT || !(cpr->read(&val))); + } +}; + +/** + * @brief Functor to read optional structures in unions + * + * @return True if field types mismatch + */ +template +class parquet_field_union_struct : public parquet_field { + E& enum_val; + thrust::optional& val; // union structs are always wrapped in std::optional + + public: + parquet_field_union_struct(int f, E& ev, thrust::optional& v) + : parquet_field(f), enum_val(ev), val(v) + { + } + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + T v; + bool const res = parquet_field_struct(field(), v).operator()(cpr, field_type); + if (!res) { + val = v; + enum_val = static_cast(field()); + } + return res; + } +}; + +/** + * @brief Functor to read empty structures in unions + * + * Added to avoid having to define read() functions for empty structs contained in unions. + * + * @return True if field types mismatch + */ +template +class parquet_field_union_enumerator : public parquet_field { + E& val; + + public: + parquet_field_union_enumerator(int f, E& v) : parquet_field(f), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if (field_type != ST_FLD_STRUCT) { return true; } + cpr->skip_struct_field(field_type); + val = static_cast(field()); + return false; + } +}; + +/** + * @brief Functor to read a vector of structures from CompactProtocolReader + * + * @return True if field types mismatch or if the process of reading a + * struct fails + */ +template +struct parquet_field_struct_list : public parquet_field_list { + parquet_field_struct_list(int f, std::vector& v) : parquet_field_list(f, v, ST_FLD_STRUCT) + { + auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) { + if (not cpr->read(&this->val[i])) { return true; } + return false; + }; + this->bind_read_func(read_value); + } +}; + +// TODO(ets): replace current union handling (which mirrors thrift) to use std::optional fields +// in a struct +/** + * @brief Functor to read a union member from CompactProtocolReader + * + * @tparam is_empty True if tparam `T` type is empty type, else false. + * + * @return True if field types mismatch or if the process of reading a + * union member fails + */ +template +class ParquetFieldUnionFunctor : public parquet_field { + bool& is_set; + T& val; + + public: + ParquetFieldUnionFunctor(int f, bool& b, T& v) : parquet_field(f), is_set(b), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if (field_type != ST_FLD_STRUCT) { + return true; + } else { + is_set = true; + return !cpr->read(&val); + } + } +}; + +template +class ParquetFieldUnionFunctor : public parquet_field { + bool& is_set; + T& val; + + public: + ParquetFieldUnionFunctor(int f, bool& b, T& v) : parquet_field(f), is_set(b), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if (field_type != ST_FLD_STRUCT) { + return true; + } else { + is_set = true; + cpr->skip_struct_field(field_type); + return false; + } + } +}; + +template +ParquetFieldUnionFunctor> ParquetFieldUnion(int f, bool& b, T& v) +{ + return ParquetFieldUnionFunctor>(f, b, v); +} + +/** + * @brief Functor to read a binary from CompactProtocolReader + * + * @return True if field type mismatches or if size of binary exceeds bounds + * of the CompactProtocolReader + */ +class parquet_field_binary : public parquet_field { + std::vector& val; + + public: + parquet_field_binary(int f, std::vector& v) : parquet_field(f), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if (field_type != ST_FLD_BINARY) { return true; } + auto const n = cpr->get_u32(); + if (n <= static_cast(cpr->m_end - cpr->m_cur)) { + val.resize(n); + val.assign(cpr->m_cur, cpr->m_cur + n); + cpr->m_cur += n; + return false; + } else { + return true; + } + } +}; + +/** + * @brief Functor to read a vector of binaries from CompactProtocolReader + * + * @return True if field types mismatch or if the process of reading a + * binary fails + */ +struct parquet_field_binary_list : public parquet_field_list> { + parquet_field_binary_list(int f, std::vector>& v) + : parquet_field_list(f, v, ST_FLD_BINARY) + { + auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) { + auto const l = cpr->get_u32(); + if (l <= static_cast(cpr->m_end - cpr->m_cur)) { + val[i].resize(l); + val[i].assign(cpr->m_cur, cpr->m_cur + l); + cpr->m_cur += l; + } else { + return true; + } + return false; + }; + bind_read_func(read_value); + } +}; + +/** + * @brief Functor to read a struct from CompactProtocolReader + * + * @return True if field type mismatches + */ +class parquet_field_struct_blob : public parquet_field { + std::vector& val; + + public: + parquet_field_struct_blob(int f, std::vector& v) : parquet_field(f), val(v) {} + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + if (field_type != ST_FLD_STRUCT) { return true; } + uint8_t const* const start = cpr->m_cur; + cpr->skip_struct_field(field_type); + if (cpr->m_cur > start) { val.assign(start, cpr->m_cur - 1); } + return false; + } +}; + +/** + * @brief functor to wrap functors for optional fields + */ +template +class parquet_field_optional : public parquet_field { + thrust::optional& val; + + public: + parquet_field_optional(int f, thrust::optional& v) : parquet_field(f), val(v) {} + + inline bool operator()(CompactProtocolReader* cpr, int field_type) + { + T v; + bool const res = FieldFunctor(field(), v).operator()(cpr, field_type); + if (!res) { val = v; } + return res; + } +}; /** * @brief Skips the number of bytes according to the specified struct type @@ -59,22 +506,21 @@ bool CompactProtocolReader::skip_struct_field(int t, int depth) case ST_FLD_BYTE: skip_bytes(1); break; case ST_FLD_DOUBLE: skip_bytes(8); break; case ST_FLD_BINARY: skip_bytes(get_u32()); break; - case ST_FLD_LIST: + case ST_FLD_LIST: [[fallthrough]]; case ST_FLD_SET: { - int c = getb(); - int n = c >> 4; - if (n == 0xf) n = get_i32(); - t = g_list2struct[c & 0xf]; - if (depth > 10) return false; - for (int32_t i = 0; i < n; i++) + auto const [t, n] = get_listh(); + if (depth > 10) { return false; } + for (uint32_t i = 0; i < n; i++) { skip_struct_field(t, depth + 1); + } } break; case ST_FLD_STRUCT: for (;;) { - int c = getb(); - t = c & 0xf; - if (!c) break; - if (depth > 10) return false; + int const c = getb(); + t = c & 0xf; + if (c == 0) { break; } // end of struct + if ((c & 0xf0) == 0) { get_i16(); } // field id is not a delta + if (depth > 10) { return false; } skip_struct_field(t, depth + 1); } break; @@ -125,11 +571,11 @@ inline bool function_builder(CompactProtocolReader* cpr, std::tuple int field = 0; while (true) { int const current_byte = cpr->getb(); - if (!current_byte) break; - int const field_delta = current_byte >> 4; - int const field_type = current_byte & 0xf; - field = field_delta ? field + field_delta : cpr->get_i16(); - bool exit_function = FunctionSwitchImpl::run(cpr, field_type, field, op); + if (!current_byte) { break; } + int const field_delta = current_byte >> 4; + int const field_type = current_byte & 0xf; + field = field_delta ? field + field_delta : cpr->get_i16(); + bool const exit_function = FunctionSwitchImpl::run(cpr, field_type, field, op); if (exit_function) { return false; } } return true; @@ -137,27 +583,30 @@ inline bool function_builder(CompactProtocolReader* cpr, std::tuple bool CompactProtocolReader::read(FileMetaData* f) { - auto op = std::make_tuple(ParquetFieldInt32(1, f->version), - ParquetFieldStructList(2, f->schema), - ParquetFieldInt64(3, f->num_rows), - ParquetFieldStructList(4, f->row_groups), - ParquetFieldStructList(5, f->key_value_metadata), - ParquetFieldString(6, f->created_by)); + using optional_list_column_order = + parquet_field_optional, parquet_field_struct_list>; + auto op = std::make_tuple(parquet_field_int32(1, f->version), + parquet_field_struct_list(2, f->schema), + parquet_field_int64(3, f->num_rows), + parquet_field_struct_list(4, f->row_groups), + parquet_field_struct_list(5, f->key_value_metadata), + parquet_field_string(6, f->created_by), + optional_list_column_order(7, f->column_orders)); return function_builder(this, op); } bool CompactProtocolReader::read(SchemaElement* s) { - auto op = std::make_tuple(ParquetFieldEnum(1, s->type), - ParquetFieldInt32(2, s->type_length), - ParquetFieldEnum(3, s->repetition_type), - ParquetFieldString(4, s->name), - ParquetFieldInt32(5, s->num_children), - ParquetFieldEnum(6, s->converted_type), - ParquetFieldInt32(7, s->decimal_scale), - ParquetFieldInt32(8, s->decimal_precision), - ParquetFieldOptionalInt32(9, s->field_id), - ParquetFieldStruct(10, s->logical_type)); + auto op = std::make_tuple(parquet_field_enum(1, s->type), + parquet_field_int32(2, s->type_length), + parquet_field_enum(3, s->repetition_type), + parquet_field_string(4, s->name), + parquet_field_int32(5, s->num_children), + parquet_field_enum(6, s->converted_type), + parquet_field_int32(7, s->decimal_scale), + parquet_field_int32(8, s->decimal_precision), + parquet_field_optional(9, s->field_id), + parquet_field_struct(10, s->logical_type)); return function_builder(this, op); } @@ -181,21 +630,21 @@ bool CompactProtocolReader::read(LogicalType* l) bool CompactProtocolReader::read(DecimalType* d) { - auto op = std::make_tuple(ParquetFieldInt32(1, d->scale), ParquetFieldInt32(2, d->precision)); + auto op = std::make_tuple(parquet_field_int32(1, d->scale), parquet_field_int32(2, d->precision)); return function_builder(this, op); } bool CompactProtocolReader::read(TimeType* t) { auto op = - std::make_tuple(ParquetFieldBool(1, t->isAdjustedToUTC), ParquetFieldStruct(2, t->unit)); + std::make_tuple(parquet_field_bool(1, t->isAdjustedToUTC), parquet_field_struct(2, t->unit)); return function_builder(this, op); } bool CompactProtocolReader::read(TimestampType* t) { auto op = - std::make_tuple(ParquetFieldBool(1, t->isAdjustedToUTC), ParquetFieldStruct(2, t->unit)); + std::make_tuple(parquet_field_bool(1, t->isAdjustedToUTC), parquet_field_struct(2, t->unit)); return function_builder(this, op); } @@ -209,123 +658,129 @@ bool CompactProtocolReader::read(TimeUnit* u) bool CompactProtocolReader::read(IntType* i) { - auto op = std::make_tuple(ParquetFieldInt8(1, i->bitWidth), ParquetFieldBool(2, i->isSigned)); + auto op = std::make_tuple(parquet_field_int8(1, i->bitWidth), parquet_field_bool(2, i->isSigned)); return function_builder(this, op); } bool CompactProtocolReader::read(RowGroup* r) { - auto op = std::make_tuple(ParquetFieldStructList(1, r->columns), - ParquetFieldInt64(2, r->total_byte_size), - ParquetFieldInt64(3, r->num_rows)); + auto op = std::make_tuple(parquet_field_struct_list(1, r->columns), + parquet_field_int64(2, r->total_byte_size), + parquet_field_int64(3, r->num_rows)); return function_builder(this, op); } bool CompactProtocolReader::read(ColumnChunk* c) { - auto op = std::make_tuple(ParquetFieldString(1, c->file_path), - ParquetFieldInt64(2, c->file_offset), - ParquetFieldStruct(3, c->meta_data), - ParquetFieldInt64(4, c->offset_index_offset), - ParquetFieldInt32(5, c->offset_index_length), - ParquetFieldInt64(6, c->column_index_offset), - ParquetFieldInt32(7, c->column_index_length)); + auto op = std::make_tuple(parquet_field_string(1, c->file_path), + parquet_field_int64(2, c->file_offset), + parquet_field_struct(3, c->meta_data), + parquet_field_int64(4, c->offset_index_offset), + parquet_field_int32(5, c->offset_index_length), + parquet_field_int64(6, c->column_index_offset), + parquet_field_int32(7, c->column_index_length)); return function_builder(this, op); } bool CompactProtocolReader::read(ColumnChunkMetaData* c) { - auto op = std::make_tuple(ParquetFieldEnum(1, c->type), - ParquetFieldEnumList(2, c->encodings), - ParquetFieldStringList(3, c->path_in_schema), - ParquetFieldEnum(4, c->codec), - ParquetFieldInt64(5, c->num_values), - ParquetFieldInt64(6, c->total_uncompressed_size), - ParquetFieldInt64(7, c->total_compressed_size), - ParquetFieldInt64(9, c->data_page_offset), - ParquetFieldInt64(10, c->index_page_offset), - ParquetFieldInt64(11, c->dictionary_page_offset), - ParquetFieldStruct(12, c->statistics)); + auto op = std::make_tuple(parquet_field_enum(1, c->type), + parquet_field_enum_list(2, c->encodings), + parquet_field_string_list(3, c->path_in_schema), + parquet_field_enum(4, c->codec), + parquet_field_int64(5, c->num_values), + parquet_field_int64(6, c->total_uncompressed_size), + parquet_field_int64(7, c->total_compressed_size), + parquet_field_int64(9, c->data_page_offset), + parquet_field_int64(10, c->index_page_offset), + parquet_field_int64(11, c->dictionary_page_offset), + parquet_field_struct(12, c->statistics)); return function_builder(this, op); } bool CompactProtocolReader::read(PageHeader* p) { - auto op = std::make_tuple(ParquetFieldEnum(1, p->type), - ParquetFieldInt32(2, p->uncompressed_page_size), - ParquetFieldInt32(3, p->compressed_page_size), - ParquetFieldStruct(5, p->data_page_header), - ParquetFieldStruct(7, p->dictionary_page_header), - ParquetFieldStruct(8, p->data_page_header_v2)); + auto op = std::make_tuple(parquet_field_enum(1, p->type), + parquet_field_int32(2, p->uncompressed_page_size), + parquet_field_int32(3, p->compressed_page_size), + parquet_field_struct(5, p->data_page_header), + parquet_field_struct(7, p->dictionary_page_header), + parquet_field_struct(8, p->data_page_header_v2)); return function_builder(this, op); } bool CompactProtocolReader::read(DataPageHeader* d) { - auto op = std::make_tuple(ParquetFieldInt32(1, d->num_values), - ParquetFieldEnum(2, d->encoding), - ParquetFieldEnum(3, d->definition_level_encoding), - ParquetFieldEnum(4, d->repetition_level_encoding)); + auto op = std::make_tuple(parquet_field_int32(1, d->num_values), + parquet_field_enum(2, d->encoding), + parquet_field_enum(3, d->definition_level_encoding), + parquet_field_enum(4, d->repetition_level_encoding)); return function_builder(this, op); } bool CompactProtocolReader::read(DictionaryPageHeader* d) { - auto op = std::make_tuple(ParquetFieldInt32(1, d->num_values), - ParquetFieldEnum(2, d->encoding)); + auto op = std::make_tuple(parquet_field_int32(1, d->num_values), + parquet_field_enum(2, d->encoding)); return function_builder(this, op); } bool CompactProtocolReader::read(DataPageHeaderV2* d) { - auto op = std::make_tuple(ParquetFieldInt32(1, d->num_values), - ParquetFieldInt32(2, d->num_nulls), - ParquetFieldInt32(3, d->num_rows), - ParquetFieldEnum(4, d->encoding), - ParquetFieldInt32(5, d->definition_levels_byte_length), - ParquetFieldInt32(6, d->repetition_levels_byte_length), - ParquetFieldBool(7, d->is_compressed)); + auto op = std::make_tuple(parquet_field_int32(1, d->num_values), + parquet_field_int32(2, d->num_nulls), + parquet_field_int32(3, d->num_rows), + parquet_field_enum(4, d->encoding), + parquet_field_int32(5, d->definition_levels_byte_length), + parquet_field_int32(6, d->repetition_levels_byte_length), + parquet_field_bool(7, d->is_compressed)); return function_builder(this, op); } bool CompactProtocolReader::read(KeyValue* k) { - auto op = std::make_tuple(ParquetFieldString(1, k->key), ParquetFieldString(2, k->value)); + auto op = std::make_tuple(parquet_field_string(1, k->key), parquet_field_string(2, k->value)); return function_builder(this, op); } bool CompactProtocolReader::read(PageLocation* p) { - auto op = std::make_tuple(ParquetFieldInt64(1, p->offset), - ParquetFieldInt32(2, p->compressed_page_size), - ParquetFieldInt64(3, p->first_row_index)); + auto op = std::make_tuple(parquet_field_int64(1, p->offset), + parquet_field_int32(2, p->compressed_page_size), + parquet_field_int64(3, p->first_row_index)); return function_builder(this, op); } bool CompactProtocolReader::read(OffsetIndex* o) { - auto op = std::make_tuple(ParquetFieldStructList(1, o->page_locations)); + auto op = std::make_tuple(parquet_field_struct_list(1, o->page_locations)); return function_builder(this, op); } bool CompactProtocolReader::read(ColumnIndex* c) { - auto op = std::make_tuple(ParquetFieldBoolList(1, c->null_pages), - ParquetFieldBinaryList(2, c->min_values), - ParquetFieldBinaryList(3, c->max_values), - ParquetFieldEnum(4, c->boundary_order), - ParquetFieldInt64List(5, c->null_counts)); + auto op = std::make_tuple(parquet_field_bool_list(1, c->null_pages), + parquet_field_binary_list(2, c->min_values), + parquet_field_binary_list(3, c->max_values), + parquet_field_enum(4, c->boundary_order), + parquet_field_int64_list(5, c->null_counts)); return function_builder(this, op); } bool CompactProtocolReader::read(Statistics* s) { - auto op = std::make_tuple(ParquetFieldBinary(1, s->max), - ParquetFieldBinary(2, s->min), - ParquetFieldInt64(3, s->null_count), - ParquetFieldInt64(4, s->distinct_count), - ParquetFieldBinary(5, s->max_value), - ParquetFieldBinary(6, s->min_value)); + auto op = std::make_tuple(parquet_field_binary(1, s->max), + parquet_field_binary(2, s->min), + parquet_field_int64(3, s->null_count), + parquet_field_int64(4, s->distinct_count), + parquet_field_binary(5, s->max_value), + parquet_field_binary(6, s->min_value)); + return function_builder(this, op); +} + +bool CompactProtocolReader::read(ColumnOrder* c) +{ + auto op = std::make_tuple(parquet_field_union_enumerator(1, c->type)); return function_builder(this, op); } @@ -338,7 +793,7 @@ bool CompactProtocolReader::read(Statistics* s) */ bool CompactProtocolReader::InitSchema(FileMetaData* md) { - if (static_cast(WalkSchema(md)) != md->schema.size()) return false; + if (static_cast(WalkSchema(md)) != md->schema.size()) { return false; } /* Inside FileMetaData, there is a std::vector of RowGroups and each RowGroup contains a * a std::vector of ColumnChunks. Each ColumnChunk has a member ColumnMetaData, which contains @@ -353,13 +808,15 @@ bool CompactProtocolReader::InitSchema(FileMetaData* md) for (auto const& path : column.meta_data.path_in_schema) { auto const it = [&] { // find_if starting at (current_schema_index + 1) and then wrapping - auto schema = [&](auto const& e) { return e.parent_idx == parent && e.name == path; }; - auto mid = md->schema.cbegin() + current_schema_index + 1; - auto it = std::find_if(mid, md->schema.cend(), schema); - if (it != md->schema.cend()) return it; + auto const schema = [&](auto const& e) { + return e.parent_idx == parent && e.name == path; + }; + auto const mid = md->schema.cbegin() + current_schema_index + 1; + auto const it = std::find_if(mid, md->schema.cend(), schema); + if (it != md->schema.cend()) { return it; } return std::find_if(md->schema.cbegin(), mid, schema); }(); - if (it == md->schema.cend()) return false; + if (it == md->schema.cend()) { return false; } current_schema_index = std::distance(md->schema.cbegin(), it); column.schema_idx = current_schema_index; parent = current_schema_index; @@ -401,9 +858,9 @@ int CompactProtocolReader::WalkSchema( if (e->num_children > 0) { for (int i = 0; i < e->num_children; i++) { e->children_idx.push_back(idx); - int idx_old = idx; - idx = WalkSchema(md, idx, parent_idx, max_def_level, max_rep_level); - if (idx <= idx_old) break; // Error + int const idx_old = idx; + idx = WalkSchema(md, idx, parent_idx, max_def_level, max_rep_level); + if (idx <= idx_old) { break; } // Error } } return idx; diff --git a/cpp/src/io/parquet/compact_protocol_reader.hpp b/cpp/src/io/parquet/compact_protocol_reader.hpp index 62ccacaac37..619815db503 100644 --- a/cpp/src/io/parquet/compact_protocol_reader.hpp +++ b/cpp/src/io/parquet/compact_protocol_reader.hpp @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace cudf { @@ -40,9 +41,6 @@ namespace parquet { * compression codecs are supported yet. */ class CompactProtocolReader { - protected: - static const uint8_t g_list2struct[16]; - public: explicit CompactProtocolReader(uint8_t const* base = nullptr, size_t len = 0) { init(base, len); } void init(uint8_t const* base, size_t len) @@ -57,45 +55,46 @@ class CompactProtocolReader { bytecnt = std::min(bytecnt, (size_t)(m_end - m_cur)); m_cur += bytecnt; } - uint32_t get_u32() noexcept + + // returns a varint encoded integer + template + T get_varint() noexcept { - uint32_t v = 0; + T v = 0; for (uint32_t l = 0;; l += 7) { - uint32_t c = getb(); + T c = getb(); v |= (c & 0x7f) << l; - if (c < 0x80) break; + if (c < 0x80) { break; } } return v; } - uint64_t get_u64() noexcept - { - uint64_t v = 0; - for (uint64_t l = 0;; l += 7) { - uint64_t c = getb(); - v |= (c & 0x7f) << l; - if (c < 0x80) break; - } - return v; - } - int32_t get_i16() noexcept { return get_i32(); } - int32_t get_i32() noexcept - { - uint32_t u = get_u32(); - return (int32_t)((u >> 1u) ^ -(int32_t)(u & 1)); - } - int64_t get_i64() noexcept + + // returns a zigzag encoded signed integer + template + T get_zigzag() noexcept { - uint64_t u = get_u64(); - return (int64_t)((u >> 1u) ^ -(int64_t)(u & 1)); + using U = std::make_unsigned_t; + U const u = get_varint(); + return static_cast((u >> 1u) ^ -static_cast(u & 1)); } - int32_t get_listh(uint8_t* el_type) noexcept + + // thrift spec says to use zigzag i32 for i16 types + int32_t get_i16() noexcept { return get_zigzag(); } + int32_t get_i32() noexcept { return get_zigzag(); } + int64_t get_i64() noexcept { return get_zigzag(); } + + uint32_t get_u32() noexcept { return get_varint(); } + uint64_t get_u64() noexcept { return get_varint(); } + + [[nodiscard]] std::pair get_listh() noexcept { - uint32_t c = getb(); - int32_t sz = c >> 4; - *el_type = c & 0xf; - if (sz == 0xf) sz = get_u32(); - return sz; + uint32_t const c = getb(); + uint32_t sz = c >> 4; + uint8_t t = c & 0xf; + if (sz == 0xf) { sz = get_u32(); } + return {t, sz}; } + bool skip_struct_field(int t, int depth = 0); public: @@ -120,6 +119,7 @@ class CompactProtocolReader { bool read(OffsetIndex* o); bool read(ColumnIndex* c); bool read(Statistics* s); + bool read(ColumnOrder* c); public: static int NumRequiredBits(uint32_t max_level) noexcept @@ -140,523 +140,11 @@ class CompactProtocolReader { uint8_t const* m_cur = nullptr; uint8_t const* m_end = nullptr; - friend class ParquetFieldBool; - friend class ParquetFieldBoolList; - friend class ParquetFieldInt8; - friend class ParquetFieldInt32; - friend class ParquetFieldOptionalInt32; - friend class ParquetFieldInt64; - friend class ParquetFieldInt64List; - template - friend class ParquetFieldStructListFunctor; - friend class ParquetFieldString; - template - friend class ParquetFieldStructFunctor; - template - friend class ParquetFieldUnionFunctor; - template - friend class ParquetFieldEnum; - template - friend class ParquetFieldEnumListFunctor; - friend class ParquetFieldStringList; - friend class ParquetFieldBinary; - friend class ParquetFieldBinaryList; - friend class ParquetFieldStructBlob; -}; - -/** - * @brief Functor to set value to bool read from CompactProtocolReader - * - * @return True if field type is not bool - */ -class ParquetFieldBool { - int field_val; - bool& val; - - public: - ParquetFieldBool(int f, bool& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - return (field_type != ST_FLD_TRUE && field_type != ST_FLD_FALSE) || - !(val = (field_type == ST_FLD_TRUE), true); - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to read a vector of booleans from CompactProtocolReader - * - * @return True if field types mismatch or if the process of reading a - * bool fails - */ -class ParquetFieldBoolList { - int field_val; - std::vector& val; - - public: - ParquetFieldBoolList(int f, std::vector& v) : field_val(f), val(v) {} - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_LIST) return true; - uint8_t t; - int32_t n = cpr->get_listh(&t); - if (t != ST_FLD_TRUE) return true; - val.resize(n); - for (int32_t i = 0; i < n; i++) { - unsigned int current_byte = cpr->getb(); - if (current_byte != ST_FLD_TRUE && current_byte != ST_FLD_FALSE) return true; - val[i] = current_byte == ST_FLD_TRUE; - } - return false; - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to set value to 8 bit integer read from CompactProtocolReader - * - * @return True if field type is not int8 - */ -class ParquetFieldInt8 { - int field_val; - int8_t& val; - - public: - ParquetFieldInt8(int f, int8_t& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - val = cpr->getb(); - return (field_type != ST_FLD_BYTE); - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to set value to 32 bit integer read from CompactProtocolReader - * - * @return True if field type is not int32 - */ -class ParquetFieldInt32 { - int field_val; - int32_t& val; - - public: - ParquetFieldInt32(int f, int32_t& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - val = cpr->get_i32(); - return (field_type != ST_FLD_I32); - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to set value to optional 32 bit integer read from CompactProtocolReader - * - * @return True if field type is not int32 - */ -class ParquetFieldOptionalInt32 { - int field_val; - std::optional& val; - - public: - ParquetFieldOptionalInt32(int f, std::optional& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - val = cpr->get_i32(); - return (field_type != ST_FLD_I32); - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to set value to 64 bit integer read from CompactProtocolReader - * - * @return True if field type is not int32 or int64 - */ -class ParquetFieldInt64 { - int field_val; - int64_t& val; - - public: - ParquetFieldInt64(int f, int64_t& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - val = cpr->get_i64(); - return (field_type < ST_FLD_I16 || field_type > ST_FLD_I64); - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to read a vector of 64-bit integers from CompactProtocolReader - * - * @return True if field types mismatch or if the process of reading an - * int64 fails - */ -class ParquetFieldInt64List { - int field_val; - std::vector& val; - - public: - ParquetFieldInt64List(int f, std::vector& v) : field_val(f), val(v) {} - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_LIST) return true; - uint8_t t; - int32_t n = cpr->get_listh(&t); - if (t != ST_FLD_I64) return true; - val.resize(n); - for (int32_t i = 0; i < n; i++) { - val[i] = cpr->get_i64(); - } - return false; - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to read a vector of structures from CompactProtocolReader - * - * @return True if field types mismatch or if the process of reading a - * struct fails - */ -template -class ParquetFieldStructListFunctor { - int field_val; - std::vector& val; - - public: - ParquetFieldStructListFunctor(int f, std::vector& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_LIST) return true; - - int current_byte = cpr->getb(); - if ((current_byte & 0xf) != ST_FLD_STRUCT) return true; - int n = current_byte >> 4; - if (n == 0xf) n = cpr->get_u32(); - val.resize(n); - for (int32_t i = 0; i < n; i++) { - if (!(cpr->read(&val[i]))) { return true; } - } - - return false; - } - - int field() { return field_val; } -}; - -template -ParquetFieldStructListFunctor ParquetFieldStructList(int f, std::vector& v) -{ - return ParquetFieldStructListFunctor(f, v); -} - -/** - * @brief Functor to read a string from CompactProtocolReader - * - * @return True if field type mismatches or if size of string exceeds bounds - * of the CompactProtocolReader - */ -class ParquetFieldString { - int field_val; - std::string& val; - - public: - ParquetFieldString(int f, std::string& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_BINARY) return true; - uint32_t n = cpr->get_u32(); - if (n < (size_t)(cpr->m_end - cpr->m_cur)) { - val.assign((char const*)cpr->m_cur, n); - cpr->m_cur += n; - return false; - } else { - return true; - } - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to read a structure from CompactProtocolReader - * - * @return True if field types mismatch or if the process of reading a - * struct fails - */ -template -class ParquetFieldStructFunctor { - int field_val; - T& val; - - public: - ParquetFieldStructFunctor(int f, T& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - return (field_type != ST_FLD_STRUCT || !(cpr->read(&val))); - } - - int field() { return field_val; } -}; - -template -ParquetFieldStructFunctor ParquetFieldStruct(int f, T& v) -{ - return ParquetFieldStructFunctor(f, v); -} - -/** - * @brief Functor to read a union member from CompactProtocolReader - * - * @tparam is_empty True if tparam `T` type is empty type, else false. - * - * @return True if field types mismatch or if the process of reading a - * union member fails - */ -template -class ParquetFieldUnionFunctor { - int field_val; - bool& is_set; - T& val; - - public: - ParquetFieldUnionFunctor(int f, bool& b, T& v) : field_val(f), is_set(b), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_STRUCT) { - return true; - } else { - is_set = true; - return !cpr->read(&val); - } - } - - int field() { return field_val; } -}; - -template -struct ParquetFieldUnionFunctor { - int field_val; - bool& is_set; - T& val; - - public: - ParquetFieldUnionFunctor(int f, bool& b, T& v) : field_val(f), is_set(b), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_STRUCT) { - return true; - } else { - is_set = true; - cpr->skip_struct_field(field_type); - return false; - } - } - - int field() { return field_val; } -}; - -template -ParquetFieldUnionFunctor> ParquetFieldUnion(int f, bool& b, T& v) -{ - return ParquetFieldUnionFunctor>(f, b, v); -} - -/** - * @brief Functor to set value to enum read from CompactProtocolReader - * - * @return True if field type is not int32 - */ -template -class ParquetFieldEnum { - int field_val; - Enum& val; - - public: - ParquetFieldEnum(int f, Enum& v) : field_val(f), val(v) {} - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - val = static_cast(cpr->get_i32()); - return (field_type != ST_FLD_I32); - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to read a vector of enums from CompactProtocolReader - * - * @return True if field types mismatch or if the process of reading an - * enum fails - */ -template -class ParquetFieldEnumListFunctor { - int field_val; - std::vector& val; - - public: - ParquetFieldEnumListFunctor(int f, std::vector& v) : field_val(f), val(v) {} - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_LIST) return true; - int current_byte = cpr->getb(); - if ((current_byte & 0xf) != ST_FLD_I32) return true; - int n = current_byte >> 4; - if (n == 0xf) n = cpr->get_u32(); - val.resize(n); - for (int32_t i = 0; i < n; i++) { - val[i] = static_cast(cpr->get_i32()); - } - return false; - } - - int field() { return field_val; } -}; - -template -ParquetFieldEnumListFunctor ParquetFieldEnumList(int field, std::vector& v) -{ - return ParquetFieldEnumListFunctor(field, v); -} - -/** - * @brief Functor to read a vector of strings from CompactProtocolReader - * - * @return True if field types mismatch or if the process of reading a - * string fails - */ -class ParquetFieldStringList { - int field_val; - std::vector& val; - - public: - ParquetFieldStringList(int f, std::vector& v) : field_val(f), val(v) {} - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_LIST) return true; - uint8_t t; - int32_t n = cpr->get_listh(&t); - if (t != ST_FLD_BINARY) return true; - val.resize(n); - for (int32_t i = 0; i < n; i++) { - uint32_t l = cpr->get_u32(); - if (l < (size_t)(cpr->m_end - cpr->m_cur)) { - val[i].assign((char const*)cpr->m_cur, l); - cpr->m_cur += l; - } else - return true; - } - return false; - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to read a binary from CompactProtocolReader - * - * @return True if field type mismatches or if size of binary exceeds bounds - * of the CompactProtocolReader - */ -class ParquetFieldBinary { - int field_val; - std::vector& val; - - public: - ParquetFieldBinary(int f, std::vector& v) : field_val(f), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_BINARY) return true; - uint32_t n = cpr->get_u32(); - if (n <= (size_t)(cpr->m_end - cpr->m_cur)) { - val.resize(n); - val.assign(cpr->m_cur, cpr->m_cur + n); - cpr->m_cur += n; - return false; - } else { - return true; - } - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to read a vector of binaries from CompactProtocolReader - * - * @return True if field types mismatch or if the process of reading a - * binary fails - */ -class ParquetFieldBinaryList { - int field_val; - std::vector>& val; - - public: - ParquetFieldBinaryList(int f, std::vector>& v) : field_val(f), val(v) {} - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_LIST) return true; - uint8_t t; - int32_t n = cpr->get_listh(&t); - if (t != ST_FLD_BINARY) return true; - val.resize(n); - for (int32_t i = 0; i < n; i++) { - uint32_t l = cpr->get_u32(); - if (l <= (size_t)(cpr->m_end - cpr->m_cur)) { - val[i].resize(l); - val[i].assign(cpr->m_cur, cpr->m_cur + l); - cpr->m_cur += l; - } else - return true; - } - return false; - } - - int field() { return field_val; } -}; - -/** - * @brief Functor to read a struct from CompactProtocolReader - * - * @return True if field type mismatches - */ -class ParquetFieldStructBlob { - int field_val; - std::vector& val; - - public: - ParquetFieldStructBlob(int f, std::vector& v) : field_val(f), val(v) {} - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_STRUCT) return true; - uint8_t const* start = cpr->m_cur; - cpr->skip_struct_field(field_type); - if (cpr->m_cur > start) { val.assign(start, cpr->m_cur - 1); } - return false; - } - - int field() { return field_val; } + friend class parquet_field_string; + friend class parquet_field_string_list; + friend class parquet_field_binary; + friend class parquet_field_binary_list; + friend class parquet_field_struct_blob; }; } // namespace parquet diff --git a/cpp/src/io/parquet/compact_protocol_writer.cpp b/cpp/src/io/parquet/compact_protocol_writer.cpp index b2c0c97c52d..60bc8984d81 100644 --- a/cpp/src/io/parquet/compact_protocol_writer.cpp +++ b/cpp/src/io/parquet/compact_protocol_writer.cpp @@ -33,18 +33,7 @@ size_t CompactProtocolWriter::write(FileMetaData const& f) c.field_struct_list(4, f.row_groups); if (not f.key_value_metadata.empty()) { c.field_struct_list(5, f.key_value_metadata); } if (not f.created_by.empty()) { c.field_string(6, f.created_by); } - if (f.column_order_listsize != 0) { - // Dummy list of struct containing an empty field1 struct - c.put_field_header(7, c.current_field(), ST_FLD_LIST); - c.put_byte((uint8_t)((std::min(f.column_order_listsize, 0xfu) << 4) | ST_FLD_STRUCT)); - if (f.column_order_listsize >= 0xf) c.put_uint(f.column_order_listsize); - for (uint32_t i = 0; i < f.column_order_listsize; i++) { - c.put_field_header(1, 0, ST_FLD_STRUCT); - c.put_byte(0); // ColumnOrder.field1 struct end - c.put_byte(0); // ColumnOrder struct end - } - c.set_current_field(7); - } + if (f.column_orders.has_value()) { c.field_struct_list(7, f.column_orders.value()); } return c.value(); } @@ -233,6 +222,16 @@ size_t CompactProtocolWriter::write(OffsetIndex const& s) return c.value(); } +size_t CompactProtocolWriter::write(ColumnOrder const& co) +{ + CompactProtocolFieldWriter c(*this); + switch (co) { + case ColumnOrder::TYPE_ORDER: c.field_empty_struct(1); break; + default: break; + } + return c.value(); +} + void CompactProtocolFieldWriter::put_byte(uint8_t v) { writer.m_buf.push_back(v); } void CompactProtocolFieldWriter::put_byte(uint8_t const* raw, uint32_t len) @@ -320,6 +319,13 @@ inline void CompactProtocolFieldWriter::field_struct(int field, T const& val) current_field_value = field; } +inline void CompactProtocolFieldWriter::field_empty_struct(int field) +{ + put_field_header(field, current_field_value, ST_FLD_STRUCT); + put_byte(0); // add a stop field + current_field_value = field; +} + template inline void CompactProtocolFieldWriter::field_struct_list(int field, std::vector const& val) { diff --git a/cpp/src/io/parquet/compact_protocol_writer.hpp b/cpp/src/io/parquet/compact_protocol_writer.hpp index 8d7b0961934..26d66527aa5 100644 --- a/cpp/src/io/parquet/compact_protocol_writer.hpp +++ b/cpp/src/io/parquet/compact_protocol_writer.hpp @@ -53,6 +53,7 @@ class CompactProtocolWriter { size_t write(Statistics const&); size_t write(PageLocation const&); size_t write(OffsetIndex const&); + size_t write(ColumnOrder const&); protected: std::vector& m_buf; @@ -94,6 +95,8 @@ class CompactProtocolFieldWriter { template inline void field_struct(int field, T const& val); + inline void field_empty_struct(int field); + template inline void field_struct_list(int field, std::vector const& val); diff --git a/cpp/src/io/parquet/parquet.hpp b/cpp/src/io/parquet/parquet.hpp index f7318bb9935..c2affc774c2 100644 --- a/cpp/src/io/parquet/parquet.hpp +++ b/cpp/src/io/parquet/parquet.hpp @@ -18,6 +18,8 @@ #include "parquet_common.hpp" +#include + #include #include #include @@ -118,6 +120,16 @@ struct LogicalType { BsonType BSON; }; +/** + * Union to specify the order used for the min_value and max_value fields for a column. + */ +struct ColumnOrder { + enum Type { UNDEFINED, TYPE_ORDER }; + Type type; + + operator Type() const { return type; } +}; + /** * @brief Struct for describing an element/field in the Parquet format schema * @@ -135,7 +147,7 @@ struct SchemaElement { int32_t num_children = 0; int32_t decimal_scale = 0; int32_t decimal_precision = 0; - std::optional field_id = std::nullopt; + thrust::optional field_id = thrust::nullopt; bool output_as_byte_array = false; // The following fields are filled in later during schema initialization @@ -284,8 +296,8 @@ struct FileMetaData { int64_t num_rows = 0; std::vector row_groups; std::vector key_value_metadata; - std::string created_by = ""; - uint32_t column_order_listsize = 0; + std::string created_by = ""; + thrust::optional> column_orders; }; /** diff --git a/cpp/src/io/parquet/parquet_common.hpp b/cpp/src/io/parquet/parquet_common.hpp index 5f8f1617cb9..5a1716bb547 100644 --- a/cpp/src/io/parquet/parquet_common.hpp +++ b/cpp/src/io/parquet/parquet_common.hpp @@ -141,7 +141,7 @@ enum BoundaryOrder { /** * @brief Thrift compact protocol struct field types */ -enum { +enum FieldType { ST_FLD_TRUE = 1, ST_FLD_FALSE = 2, ST_FLD_BYTE = 3, diff --git a/cpp/src/io/parquet/writer_impl.cu b/cpp/src/io/parquet/writer_impl.cu index d2976a3f5d9..a124f352ee4 100644 --- a/cpp/src/io/parquet/writer_impl.cu +++ b/cpp/src/io/parquet/writer_impl.cu @@ -74,8 +74,11 @@ struct aggregate_writer_metadata { for (size_t i = 0; i < partitions.size(); ++i) { this->files[i].num_rows = partitions[i].num_rows; } - this->column_order_listsize = - (stats_granularity != statistics_freq::STATISTICS_NONE) ? num_columns : 0; + + if (stats_granularity != statistics_freq::STATISTICS_NONE) { + ColumnOrder default_order = {ColumnOrder::TYPE_ORDER}; + this->column_orders = std::vector(num_columns, default_order); + } for (size_t p = 0; p < kv_md.size(); ++p) { std::transform(kv_md[p].begin(), @@ -102,13 +105,13 @@ struct aggregate_writer_metadata { { CUDF_EXPECTS(part < files.size(), "Invalid part index queried"); FileMetaData meta{}; - meta.version = this->version; - meta.schema = this->schema; - meta.num_rows = this->files[part].num_rows; - meta.row_groups = this->files[part].row_groups; - meta.key_value_metadata = this->files[part].key_value_metadata; - meta.created_by = this->created_by; - meta.column_order_listsize = this->column_order_listsize; + meta.version = this->version; + meta.schema = this->schema; + meta.num_rows = this->files[part].num_rows; + meta.row_groups = this->files[part].row_groups; + meta.key_value_metadata = this->files[part].key_value_metadata; + meta.created_by = this->created_by; + meta.column_orders = this->column_orders; return meta; } @@ -170,8 +173,8 @@ struct aggregate_writer_metadata { std::vector> column_indexes; }; std::vector files; - std::string created_by = ""; - uint32_t column_order_listsize = 0; + std::string created_by = ""; + thrust::optional> column_orders = thrust::nullopt; }; namespace { @@ -2373,20 +2376,7 @@ std::unique_ptr> writer::merge_row_group_metadata( md.num_rows += tmp.num_rows; } } - // Reader doesn't currently populate column_order, so infer it here - if (not md.row_groups.empty()) { - auto const is_valid_stats = [](auto const& stats) { - return not stats.max.empty() || not stats.min.empty() || stats.null_count != -1 || - stats.distinct_count != -1 || not stats.max_value.empty() || - not stats.min_value.empty(); - }; - uint32_t num_columns = static_cast(md.row_groups[0].columns.size()); - md.column_order_listsize = - (num_columns > 0 && is_valid_stats(md.row_groups[0].columns[0].meta_data.statistics)) - ? num_columns - : 0; - } // Thrift-encode the resulting output file_header_s fhdr; file_ender_s fendr;