From 253f6a6d5b19387c05368e073954ff773b3d6a39 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Fri, 20 Oct 2023 14:05:53 -0700 Subject: [PATCH] Refactor LogicalType for Parquet (#14264) Continuation of #14097, this PR refactors the LogicalType struct to use the new way of treating unions defined in the parquet thrift (more enum like than struct like). Authors: - Ed Seidl (https://github.com/etseidl) - Vukasin Milovanovic (https://github.com/vuule) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/14264 --- .../io/parquet/compact_protocol_reader.cpp | 95 +++-------- .../io/parquet/compact_protocol_writer.cpp | 81 +++++---- cpp/src/io/parquet/page_decode.cuh | 3 +- cpp/src/io/parquet/parquet.hpp | 156 +++++++++++------- cpp/src/io/parquet/parquet_gpu.hpp | 30 ++-- cpp/src/io/parquet/reader_impl_chunking.cu | 13 +- cpp/src/io/parquet/reader_impl_helpers.cpp | 104 ++++++------ cpp/src/io/parquet/writer_impl.cu | 107 +++++++----- cpp/tests/io/parquet_test.cpp | 7 +- 9 files changed, 293 insertions(+), 303 deletions(-) diff --git a/cpp/src/io/parquet/compact_protocol_reader.cpp b/cpp/src/io/parquet/compact_protocol_reader.cpp index 1a345ee0750..5a2b8aa8f2a 100644 --- a/cpp/src/io/parquet/compact_protocol_reader.cpp +++ b/cpp/src/io/parquet/compact_protocol_reader.cpp @@ -339,61 +339,6 @@ struct parquet_field_struct_list : public parquet_field_list { } }; -// TODO(ets): replace current union handling (which mirrors thrift) to use std::optional fields -// in a struct -/** - * @brief Functor to read a union member from CompactProtocolReader - * - * @tparam is_empty True if tparam `T` type is empty type, else false. - * - * @return True if field types mismatch or if the process of reading a - * union member fails - */ -template -class ParquetFieldUnionFunctor : public parquet_field { - bool& is_set; - T& val; - - public: - ParquetFieldUnionFunctor(int f, bool& b, T& v) : parquet_field(f), is_set(b), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_STRUCT) { - return true; - } else { - is_set = true; - return !cpr->read(&val); - } - } -}; - -template -class ParquetFieldUnionFunctor : public parquet_field { - bool& is_set; - T& val; - - public: - ParquetFieldUnionFunctor(int f, bool& b, T& v) : parquet_field(f), is_set(b), val(v) {} - - inline bool operator()(CompactProtocolReader* cpr, int field_type) - { - if (field_type != ST_FLD_STRUCT) { - return true; - } else { - is_set = true; - cpr->skip_struct_field(field_type); - return false; - } - } -}; - -template -ParquetFieldUnionFunctor> ParquetFieldUnion(int f, bool& b, T& v) -{ - return ParquetFieldUnionFunctor>(f, b, v); -} - /** * @brief Functor to read a binary from CompactProtocolReader * @@ -595,34 +540,38 @@ bool CompactProtocolReader::read(FileMetaData* f) bool CompactProtocolReader::read(SchemaElement* s) { + using optional_converted_type = + parquet_field_optional>; + using optional_logical_type = + parquet_field_optional>; auto op = std::make_tuple(parquet_field_enum(1, s->type), parquet_field_int32(2, s->type_length), parquet_field_enum(3, s->repetition_type), parquet_field_string(4, s->name), parquet_field_int32(5, s->num_children), - parquet_field_enum(6, s->converted_type), + optional_converted_type(6, s->converted_type), parquet_field_int32(7, s->decimal_scale), parquet_field_int32(8, s->decimal_precision), parquet_field_optional(9, s->field_id), - parquet_field_struct(10, s->logical_type)); + optional_logical_type(10, s->logical_type)); return function_builder(this, op); } bool CompactProtocolReader::read(LogicalType* l) { - auto op = - std::make_tuple(ParquetFieldUnion(1, l->isset.STRING, l->STRING), - ParquetFieldUnion(2, l->isset.MAP, l->MAP), - ParquetFieldUnion(3, l->isset.LIST, l->LIST), - ParquetFieldUnion(4, l->isset.ENUM, l->ENUM), - ParquetFieldUnion(5, l->isset.DECIMAL, l->DECIMAL), // read the struct - ParquetFieldUnion(6, l->isset.DATE, l->DATE), - ParquetFieldUnion(7, l->isset.TIME, l->TIME), // read the struct - ParquetFieldUnion(8, l->isset.TIMESTAMP, l->TIMESTAMP), // read the struct - ParquetFieldUnion(10, l->isset.INTEGER, l->INTEGER), // read the struct - ParquetFieldUnion(11, l->isset.UNKNOWN, l->UNKNOWN), - ParquetFieldUnion(12, l->isset.JSON, l->JSON), - ParquetFieldUnion(13, l->isset.BSON, l->BSON)); + auto op = std::make_tuple( + parquet_field_union_enumerator(1, l->type), + parquet_field_union_enumerator(2, l->type), + parquet_field_union_enumerator(3, l->type), + parquet_field_union_enumerator(4, l->type), + parquet_field_union_struct(5, l->type, l->decimal_type), + parquet_field_union_enumerator(6, l->type), + parquet_field_union_struct(7, l->type, l->time_type), + parquet_field_union_struct(8, l->type, l->timestamp_type), + parquet_field_union_struct(10, l->type, l->int_type), + parquet_field_union_enumerator(11, l->type), + parquet_field_union_enumerator(12, l->type), + parquet_field_union_enumerator(13, l->type)); return function_builder(this, op); } @@ -648,9 +597,9 @@ bool CompactProtocolReader::read(TimestampType* t) bool CompactProtocolReader::read(TimeUnit* u) { - auto op = std::make_tuple(ParquetFieldUnion(1, u->isset.MILLIS, u->MILLIS), - ParquetFieldUnion(2, u->isset.MICROS, u->MICROS), - ParquetFieldUnion(3, u->isset.NANOS, u->NANOS)); + auto op = std::make_tuple(parquet_field_union_enumerator(1, u->type), + parquet_field_union_enumerator(2, u->type), + parquet_field_union_enumerator(3, u->type)); return function_builder(this, op); } diff --git a/cpp/src/io/parquet/compact_protocol_writer.cpp b/cpp/src/io/parquet/compact_protocol_writer.cpp index 00810269d3c..fbeda7f1099 100644 --- a/cpp/src/io/parquet/compact_protocol_writer.cpp +++ b/cpp/src/io/parquet/compact_protocol_writer.cpp @@ -16,6 +16,8 @@ #include "compact_protocol_writer.hpp" +#include + namespace cudf::io::parquet::detail { /** @@ -46,13 +48,11 @@ size_t CompactProtocolWriter::write(DecimalType const& decimal) size_t CompactProtocolWriter::write(TimeUnit const& time_unit) { CompactProtocolFieldWriter c(*this); - auto const isset = time_unit.isset; - if (isset.MILLIS) { - c.field_struct(1, time_unit.MILLIS); - } else if (isset.MICROS) { - c.field_struct(2, time_unit.MICROS); - } else if (isset.NANOS) { - c.field_struct(3, time_unit.NANOS); + switch (time_unit.type) { + case TimeUnit::MILLIS: + case TimeUnit::MICROS: + case TimeUnit::NANOS: c.field_empty_struct(time_unit.type); break; + default: CUDF_FAIL("Trying to write an invalid TimeUnit " + std::to_string(time_unit.type)); } return c.value(); } @@ -84,31 +84,29 @@ size_t CompactProtocolWriter::write(IntType const& integer) size_t CompactProtocolWriter::write(LogicalType const& logical_type) { CompactProtocolFieldWriter c(*this); - auto const isset = logical_type.isset; - if (isset.STRING) { - c.field_struct(1, logical_type.STRING); - } else if (isset.MAP) { - c.field_struct(2, logical_type.MAP); - } else if (isset.LIST) { - c.field_struct(3, logical_type.LIST); - } else if (isset.ENUM) { - c.field_struct(4, logical_type.ENUM); - } else if (isset.DECIMAL) { - c.field_struct(5, logical_type.DECIMAL); - } else if (isset.DATE) { - c.field_struct(6, logical_type.DATE); - } else if (isset.TIME) { - c.field_struct(7, logical_type.TIME); - } else if (isset.TIMESTAMP) { - c.field_struct(8, logical_type.TIMESTAMP); - } else if (isset.INTEGER) { - c.field_struct(10, logical_type.INTEGER); - } else if (isset.UNKNOWN) { - c.field_struct(11, logical_type.UNKNOWN); - } else if (isset.JSON) { - c.field_struct(12, logical_type.JSON); - } else if (isset.BSON) { - c.field_struct(13, logical_type.BSON); + switch (logical_type.type) { + case LogicalType::STRING: + case LogicalType::MAP: + case LogicalType::LIST: + case LogicalType::ENUM: + case LogicalType::DATE: + case LogicalType::UNKNOWN: + case LogicalType::JSON: + case LogicalType::BSON: c.field_empty_struct(logical_type.type); break; + case LogicalType::DECIMAL: + c.field_struct(LogicalType::DECIMAL, logical_type.decimal_type.value()); + break; + case LogicalType::TIME: + c.field_struct(LogicalType::TIME, logical_type.time_type.value()); + break; + case LogicalType::TIMESTAMP: + c.field_struct(LogicalType::TIMESTAMP, logical_type.timestamp_type.value()); + break; + case LogicalType::INTEGER: + c.field_struct(LogicalType::INTEGER, logical_type.int_type.value()); + break; + default: + CUDF_FAIL("Trying to write an invalid LogicalType " + std::to_string(logical_type.type)); } return c.value(); } @@ -124,20 +122,15 @@ size_t CompactProtocolWriter::write(SchemaElement const& s) c.field_string(4, s.name); if (s.type == UNDEFINED_TYPE) { c.field_int(5, s.num_children); } - if (s.converted_type != UNKNOWN) { - c.field_int(6, s.converted_type); + if (s.converted_type.has_value()) { + c.field_int(6, s.converted_type.value()); if (s.converted_type == DECIMAL) { c.field_int(7, s.decimal_scale); c.field_int(8, s.decimal_precision); } } - if (s.field_id) { c.field_int(9, s.field_id.value()); } - auto const isset = s.logical_type.isset; - // TODO: add handling for all logical types - // if (isset.STRING or isset.MAP or isset.LIST or isset.ENUM or isset.DECIMAL or isset.DATE or - // isset.TIME or isset.TIMESTAMP or isset.INTEGER or isset.UNKNOWN or isset.JSON or isset.BSON) - // { - if (isset.TIMESTAMP or isset.TIME) { c.field_struct(10, s.logical_type); } + if (s.field_id.has_value()) { c.field_int(9, s.field_id.value()); } + if (s.logical_type.has_value()) { c.field_struct(10, s.logical_type.value()); } return c.value(); } @@ -223,9 +216,9 @@ size_t CompactProtocolWriter::write(OffsetIndex const& s) size_t CompactProtocolWriter::write(ColumnOrder const& co) { CompactProtocolFieldWriter c(*this); - switch (co) { - case ColumnOrder::TYPE_ORDER: c.field_empty_struct(1); break; - default: break; + switch (co.type) { + case ColumnOrder::TYPE_ORDER: c.field_empty_struct(co.type); break; + default: CUDF_FAIL("Trying to write an invalid ColumnOrder " + std::to_string(co.type)); } return c.value(); } diff --git a/cpp/src/io/parquet/page_decode.cuh b/cpp/src/io/parquet/page_decode.cuh index 7c866fd8b9e..ab1cc68923d 100644 --- a/cpp/src/io/parquet/page_decode.cuh +++ b/cpp/src/io/parquet/page_decode.cuh @@ -1143,7 +1143,8 @@ inline __device__ bool setupLocalPageInfo(page_state_s* const s, units = cudf::timestamp_ms::period::den; } else if (s->col.converted_type == TIMESTAMP_MICROS) { units = cudf::timestamp_us::period::den; - } else if (s->col.logical_type.TIMESTAMP.unit.isset.NANOS) { + } else if (s->col.logical_type.has_value() and + s->col.logical_type->is_timestamp_nanos()) { units = cudf::timestamp_ns::period::den; } if (units and units != s->col.ts_clock_rate) { diff --git a/cpp/src/io/parquet/parquet.hpp b/cpp/src/io/parquet/parquet.hpp index 1cd16ac6102..699cad89703 100644 --- a/cpp/src/io/parquet/parquet.hpp +++ b/cpp/src/io/parquet/parquet.hpp @@ -46,79 +46,98 @@ struct file_ender_s { uint32_t magic; }; -// thrift generated code simplified. -struct StringType {}; -struct MapType {}; -struct ListType {}; -struct EnumType {}; +// thrift inspired code simplified. struct DecimalType { int32_t scale = 0; int32_t precision = 0; }; -struct DateType {}; - -struct MilliSeconds {}; -struct MicroSeconds {}; -struct NanoSeconds {}; -using TimeUnit_isset = struct TimeUnit_isset { - bool MILLIS{false}; - bool MICROS{false}; - bool NANOS{false}; -}; struct TimeUnit { - TimeUnit_isset isset; - MilliSeconds MILLIS; - MicroSeconds MICROS; - NanoSeconds NANOS; + enum Type { UNDEFINED, MILLIS, MICROS, NANOS }; + Type type; }; struct TimeType { bool isAdjustedToUTC = false; TimeUnit unit; }; + struct TimestampType { bool isAdjustedToUTC = false; TimeUnit unit; }; + struct IntType { int8_t bitWidth = 0; bool isSigned = false; }; -struct NullType {}; -struct JsonType {}; -struct BsonType {}; - -// thrift generated code simplified. -using LogicalType_isset = struct LogicalType_isset { - bool STRING{false}; - bool MAP{false}; - bool LIST{false}; - bool ENUM{false}; - bool DECIMAL{false}; - bool DATE{false}; - bool TIME{false}; - bool TIMESTAMP{false}; - bool INTEGER{false}; - bool UNKNOWN{false}; - bool JSON{false}; - bool BSON{false}; -}; struct LogicalType { - LogicalType_isset isset; - StringType STRING; - MapType MAP; - ListType LIST; - EnumType ENUM; - DecimalType DECIMAL; - DateType DATE; - TimeType TIME; - TimestampType TIMESTAMP; - IntType INTEGER; - NullType UNKNOWN; - JsonType JSON; - BsonType BSON; + enum Type { + UNDEFINED, + STRING, + MAP, + LIST, + ENUM, + DECIMAL, + DATE, + TIME, + TIMESTAMP, + // 9 is reserved + INTEGER = 10, + UNKNOWN, + JSON, + BSON + }; + Type type; + thrust::optional decimal_type; + thrust::optional time_type; + thrust::optional timestamp_type; + thrust::optional int_type; + + LogicalType(Type tp = UNDEFINED) : type(tp) {} + LogicalType(DecimalType&& dt) : type(DECIMAL), decimal_type(dt) {} + LogicalType(TimeType&& tt) : type(TIME), time_type(tt) {} + LogicalType(TimestampType&& tst) : type(TIMESTAMP), timestamp_type(tst) {} + LogicalType(IntType&& it) : type(INTEGER), int_type(it) {} + + constexpr bool is_time_millis() const + { + return type == TIME and time_type->unit.type == TimeUnit::MILLIS; + } + + constexpr bool is_time_micros() const + { + return type == TIME and time_type->unit.type == TimeUnit::MICROS; + } + + constexpr bool is_time_nanos() const + { + return type == TIME and time_type->unit.type == TimeUnit::NANOS; + } + + constexpr bool is_timestamp_millis() const + { + return type == TIMESTAMP and timestamp_type->unit.type == TimeUnit::MILLIS; + } + + constexpr bool is_timestamp_micros() const + { + return type == TIMESTAMP and timestamp_type->unit.type == TimeUnit::MICROS; + } + + constexpr bool is_timestamp_nanos() const + { + return type == TIMESTAMP and timestamp_type->unit.type == TimeUnit::NANOS; + } + + constexpr int8_t bit_width() const { return type == INTEGER ? int_type->bitWidth : -1; } + + constexpr bool is_signed() const { return type == INTEGER and int_type->isSigned; } + + constexpr int32_t scale() const { return type == DECIMAL ? decimal_type->scale : -1; } + + constexpr int32_t precision() const { return type == DECIMAL ? decimal_type->precision : -1; } }; /** @@ -127,8 +146,6 @@ struct LogicalType { struct ColumnOrder { enum Type { UNDEFINED, TYPE_ORDER }; Type type; - - operator Type() const { return type; } }; /** @@ -138,18 +155,29 @@ struct ColumnOrder { * as a schema tree. */ struct SchemaElement { - Type type = UNDEFINED_TYPE; - ConvertedType converted_type = UNKNOWN; - LogicalType logical_type; - int32_t type_length = - 0; // Byte length of FIXED_LENGTH_BYTE_ARRAY elements, or maximum bit length for other types + // 1: parquet physical type for output + Type type = UNDEFINED_TYPE; + // 2: byte length of FIXED_LENGTH_BYTE_ARRAY elements, or maximum bit length for other types + int32_t type_length = 0; + // 3: repetition of the field FieldRepetitionType repetition_type = REQUIRED; - std::string name = ""; - int32_t num_children = 0; - int32_t decimal_scale = 0; - int32_t decimal_precision = 0; - thrust::optional field_id = thrust::nullopt; - bool output_as_byte_array = false; + // 4: name of the field + std::string name = ""; + // 5: nested fields + int32_t num_children = 0; + // 6: DEPRECATED: record the original type before conversion to parquet type + thrust::optional converted_type; + // 7: DEPRECATED: record the scale for DECIMAL converted type + int32_t decimal_scale = 0; + // 8: DEPRECATED: record the precision for DECIMAL converted type + int32_t decimal_precision = 0; + // 9: save field_id from original schema + thrust::optional field_id; + // 10: replaces converted type + thrust::optional logical_type; + + // extra cudf specific fields + bool output_as_byte_array = false; // The following fields are filled in later during schema initialization int max_definition_level = 0; diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 164e2cea2ed..68851e72663 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -313,7 +313,7 @@ struct ColumnChunkDesc { uint8_t rep_level_bits_, int8_t codec_, int8_t converted_type_, - LogicalType logical_type_, + thrust::optional logical_type_, int8_t decimal_precision_, int32_t ts_clock_rate_, int32_t src_col_index_, @@ -355,20 +355,20 @@ struct ColumnChunkDesc { uint16_t data_type{}; // basic column data type, ((type_length << 3) | // parquet::Type) uint8_t - level_bits[level_type::NUM_LEVEL_TYPES]{}; // bits to encode max definition/repetition levels - int32_t num_data_pages{}; // number of data pages - int32_t num_dict_pages{}; // number of dictionary pages - int32_t max_num_pages{}; // size of page_info array - PageInfo* page_info{}; // output page info for up to num_dict_pages + - // num_data_pages (dictionary pages first) - string_index_pair* str_dict_index{}; // index for string dictionary - bitmask_type** valid_map_base{}; // base pointers of valid bit map for this column - void** column_data_base{}; // base pointers of column data - void** column_string_base{}; // base pointers of column string data - int8_t codec{}; // compressed codec enum - int8_t converted_type{}; // converted type enum - LogicalType logical_type{}; // logical type - int8_t decimal_precision{}; // Decimal precision + level_bits[level_type::NUM_LEVEL_TYPES]{}; // bits to encode max definition/repetition levels + int32_t num_data_pages{}; // number of data pages + int32_t num_dict_pages{}; // number of dictionary pages + int32_t max_num_pages{}; // size of page_info array + PageInfo* page_info{}; // output page info for up to num_dict_pages + + // num_data_pages (dictionary pages first) + string_index_pair* str_dict_index{}; // index for string dictionary + bitmask_type** valid_map_base{}; // base pointers of valid bit map for this column + void** column_data_base{}; // base pointers of column data + void** column_string_base{}; // base pointers of column string data + int8_t codec{}; // compressed codec enum + int8_t converted_type{}; // converted type enum + thrust::optional logical_type{}; // logical type + int8_t decimal_precision{}; // Decimal precision int32_t ts_clock_rate{}; // output timestamp clock frequency (0=default, 1000=ms, 1000000000=ns) int32_t src_col_index{}; // my input column index diff --git a/cpp/src/io/parquet/reader_impl_chunking.cu b/cpp/src/io/parquet/reader_impl_chunking.cu index ad52a7dfcc1..213fc380a34 100644 --- a/cpp/src/io/parquet/reader_impl_chunking.cu +++ b/cpp/src/io/parquet/reader_impl_chunking.cu @@ -304,11 +304,12 @@ std::vector find_splits(std::vector const& * * @return A tuple of Parquet type width, Parquet clock rate and Parquet decimal type. */ -[[nodiscard]] std::tuple conversion_info(type_id column_type_id, - type_id timestamp_type_id, - Type physical, - int8_t converted, - int32_t length) +[[nodiscard]] std::tuple conversion_info( + type_id column_type_id, + type_id timestamp_type_id, + Type physical, + thrust::optional converted, + int32_t length) { int32_t type_width = (physical == FIXED_LEN_BYTE_ARRAY) ? length : 0; int32_t clock_rate = 0; @@ -322,7 +323,7 @@ std::vector find_splits(std::vector const& clock_rate = to_clockrate(timestamp_type_id); } - int8_t converted_type = converted; + int8_t converted_type = converted.value_or(UNKNOWN); if (converted_type == DECIMAL && column_type_id != type_id::FLOAT64 && not cudf::is_fixed_point(data_type{column_type_id})) { converted_type = UNKNOWN; // Not converting to float64 or decimal diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp index 040c6403f57..a9c84143e1a 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.cpp +++ b/cpp/src/io/parquet/reader_impl_helpers.cpp @@ -25,44 +25,42 @@ namespace cudf::io::parquet::detail { namespace { -ConvertedType logical_type_to_converted_type(LogicalType const& logical) +ConvertedType logical_type_to_converted_type(thrust::optional const& logical) { - if (logical.isset.STRING) { - return UTF8; - } else if (logical.isset.MAP) { - return MAP; - } else if (logical.isset.LIST) { - return LIST; - } else if (logical.isset.ENUM) { - return ENUM; - } else if (logical.isset.DECIMAL) { - return DECIMAL; // TODO set decimal values - } else if (logical.isset.DATE) { - return DATE; - } else if (logical.isset.TIME) { - if (logical.TIME.unit.isset.MILLIS) - return TIME_MILLIS; - else if (logical.TIME.unit.isset.MICROS) - return TIME_MICROS; - } else if (logical.isset.TIMESTAMP) { - if (logical.TIMESTAMP.unit.isset.MILLIS) - return TIMESTAMP_MILLIS; - else if (logical.TIMESTAMP.unit.isset.MICROS) - return TIMESTAMP_MICROS; - } else if (logical.isset.INTEGER) { - switch (logical.INTEGER.bitWidth) { - case 8: return logical.INTEGER.isSigned ? INT_8 : UINT_8; - case 16: return logical.INTEGER.isSigned ? INT_16 : UINT_16; - case 32: return logical.INTEGER.isSigned ? INT_32 : UINT_32; - case 64: return logical.INTEGER.isSigned ? INT_64 : UINT_64; - default: break; - } - } else if (logical.isset.UNKNOWN) { - return NA; - } else if (logical.isset.JSON) { - return JSON; - } else if (logical.isset.BSON) { - return BSON; + if (not logical.has_value()) { return UNKNOWN; } + switch (logical->type) { + case LogicalType::STRING: return UTF8; + case LogicalType::MAP: return MAP; + case LogicalType::LIST: return LIST; + case LogicalType::ENUM: return ENUM; + case LogicalType::DECIMAL: return DECIMAL; // TODO use decimal scale/precision + case LogicalType::DATE: return DATE; + case LogicalType::TIME: + if (logical->is_time_millis()) { + return TIME_MILLIS; + } else if (logical->is_time_micros()) { + return TIME_MICROS; + } + break; + case LogicalType::TIMESTAMP: + if (logical->is_timestamp_millis()) { + return TIMESTAMP_MILLIS; + } else if (logical->is_timestamp_micros()) { + return TIMESTAMP_MICROS; + } + break; + case LogicalType::INTEGER: + switch (logical->bit_width()) { + case 8: return logical->is_signed() ? INT_8 : UINT_8; + case 16: return logical->is_signed() ? INT_16 : UINT_16; + case 32: return logical->is_signed() ? INT_32 : UINT_32; + case 64: return logical->is_signed() ? INT_64 : UINT_64; + default: break; + } + case LogicalType::UNKNOWN: return NA; + case LogicalType::JSON: return JSON; + case LogicalType::BSON: return BSON; + default: break; } return UNKNOWN; } @@ -76,20 +74,20 @@ type_id to_type_id(SchemaElement const& schema, bool strings_to_categorical, type_id timestamp_type_id) { - Type const physical = schema.type; - LogicalType const logical_type = schema.logical_type; - ConvertedType converted_type = schema.converted_type; - int32_t decimal_precision = schema.decimal_precision; + auto const physical = schema.type; + auto const logical_type = schema.logical_type; + auto converted_type = schema.converted_type; + int32_t decimal_precision = schema.decimal_precision; + // FIXME(ets): this should just use logical type to deduce the type_id. then fall back to + // converted_type if logical_type isn't set // Logical type used for actual data interpretation; the legacy converted type // is superseded by 'logical' type whenever available. auto const inferred_converted_type = logical_type_to_converted_type(logical_type); if (inferred_converted_type != UNKNOWN) { converted_type = inferred_converted_type; } - if (inferred_converted_type == DECIMAL) { - decimal_precision = schema.logical_type.DECIMAL.precision; - } + if (inferred_converted_type == DECIMAL) { decimal_precision = schema.logical_type->precision(); } - switch (converted_type) { + switch (converted_type.value_or(UNKNOWN)) { case UINT_8: return type_id::UINT8; case INT_8: return type_id::INT8; case UINT_16: return type_id::UINT16; @@ -140,15 +138,13 @@ type_id to_type_id(SchemaElement const& schema, default: break; } - if (inferred_converted_type == UNKNOWN and physical == INT64 and - logical_type.TIMESTAMP.unit.isset.NANOS) { - return (timestamp_type_id != type_id::EMPTY) ? timestamp_type_id - : type_id::TIMESTAMP_NANOSECONDS; - } - - if (inferred_converted_type == UNKNOWN and physical == INT64 and - logical_type.TIME.unit.isset.NANOS) { - return type_id::DURATION_NANOSECONDS; + if (inferred_converted_type == UNKNOWN and physical == INT64 and logical_type.has_value()) { + if (logical_type->is_timestamp_nanos()) { + return (timestamp_type_id != type_id::EMPTY) ? timestamp_type_id + : type_id::TIMESTAMP_NANOSECONDS; + } else if (logical_type->is_time_nanos()) { + return type_id::DURATION_NANOSECONDS; + } } // is it simply a struct? diff --git a/cpp/src/io/parquet/writer_impl.cu b/cpp/src/io/parquet/writer_impl.cu index 50589f23626..c06acc1690b 100644 --- a/cpp/src/io/parquet/writer_impl.cu +++ b/cpp/src/io/parquet/writer_impl.cu @@ -284,6 +284,7 @@ struct leaf_schema_fn { { col_schema.type = Type::BOOLEAN; col_schema.stats_dtype = statistics_dtype::dtype_bool; + // BOOLEAN needs no converted or logical type } template @@ -292,6 +293,7 @@ struct leaf_schema_fn { col_schema.type = Type::INT32; col_schema.converted_type = ConvertedType::INT_8; col_schema.stats_dtype = statistics_dtype::dtype_int8; + col_schema.logical_type = LogicalType{IntType{8, true}}; } template @@ -300,6 +302,7 @@ struct leaf_schema_fn { col_schema.type = Type::INT32; col_schema.converted_type = ConvertedType::INT_16; col_schema.stats_dtype = statistics_dtype::dtype_int16; + col_schema.logical_type = LogicalType{IntType{16, true}}; } template @@ -307,6 +310,7 @@ struct leaf_schema_fn { { col_schema.type = Type::INT32; col_schema.stats_dtype = statistics_dtype::dtype_int32; + // INT32 needs no converted or logical type } template @@ -314,6 +318,7 @@ struct leaf_schema_fn { { col_schema.type = Type::INT64; col_schema.stats_dtype = statistics_dtype::dtype_int64; + // INT64 needs no converted or logical type } template @@ -322,6 +327,7 @@ struct leaf_schema_fn { col_schema.type = Type::INT32; col_schema.converted_type = ConvertedType::UINT_8; col_schema.stats_dtype = statistics_dtype::dtype_int8; + col_schema.logical_type = LogicalType{IntType{8, false}}; } template @@ -330,6 +336,7 @@ struct leaf_schema_fn { col_schema.type = Type::INT32; col_schema.converted_type = ConvertedType::UINT_16; col_schema.stats_dtype = statistics_dtype::dtype_int16; + col_schema.logical_type = LogicalType{IntType{16, false}}; } template @@ -338,6 +345,7 @@ struct leaf_schema_fn { col_schema.type = Type::INT32; col_schema.converted_type = ConvertedType::UINT_32; col_schema.stats_dtype = statistics_dtype::dtype_int32; + col_schema.logical_type = LogicalType{IntType{32, false}}; } template @@ -346,6 +354,7 @@ struct leaf_schema_fn { col_schema.type = Type::INT64; col_schema.converted_type = ConvertedType::UINT_64; col_schema.stats_dtype = statistics_dtype::dtype_int64; + col_schema.logical_type = LogicalType{IntType{64, false}}; } template @@ -353,6 +362,7 @@ struct leaf_schema_fn { { col_schema.type = Type::FLOAT; col_schema.stats_dtype = statistics_dtype::dtype_float32; + // FLOAT needs no converted or logical type } template @@ -360,6 +370,7 @@ struct leaf_schema_fn { { col_schema.type = Type::DOUBLE; col_schema.stats_dtype = statistics_dtype::dtype_float64; + // DOUBLE needs no converted or logical type } template @@ -367,11 +378,12 @@ struct leaf_schema_fn { { col_schema.type = Type::BYTE_ARRAY; if (col_meta.is_enabled_output_as_binary()) { - col_schema.converted_type = ConvertedType::UNKNOWN; - col_schema.stats_dtype = statistics_dtype::dtype_byte_array; + col_schema.stats_dtype = statistics_dtype::dtype_byte_array; + // BYTE_ARRAY needs no converted or logical type } else { col_schema.converted_type = ConvertedType::UTF8; col_schema.stats_dtype = statistics_dtype::dtype_string; + col_schema.logical_type = LogicalType{LogicalType::STRING}; } } @@ -381,49 +393,55 @@ struct leaf_schema_fn { col_schema.type = Type::INT32; col_schema.converted_type = ConvertedType::DATE; col_schema.stats_dtype = statistics_dtype::dtype_int32; + col_schema.logical_type = LogicalType{LogicalType::DATE}; } template std::enable_if_t, void> operator()() { - col_schema.type = (timestamp_is_int96) ? Type::INT96 : Type::INT64; - col_schema.converted_type = - (timestamp_is_int96) ? ConvertedType::UNKNOWN : ConvertedType::TIMESTAMP_MILLIS; + col_schema.type = (timestamp_is_int96) ? Type::INT96 : Type::INT64; col_schema.stats_dtype = statistics_dtype::dtype_timestamp64; col_schema.ts_scale = 1000; + if (not timestamp_is_int96) { + col_schema.converted_type = ConvertedType::TIMESTAMP_MILLIS; + col_schema.logical_type = LogicalType{TimestampType{false, TimeUnit::MILLIS}}; + } } template std::enable_if_t, void> operator()() { - col_schema.type = (timestamp_is_int96) ? Type::INT96 : Type::INT64; - col_schema.converted_type = - (timestamp_is_int96) ? ConvertedType::UNKNOWN : ConvertedType::TIMESTAMP_MILLIS; + col_schema.type = (timestamp_is_int96) ? Type::INT96 : Type::INT64; col_schema.stats_dtype = statistics_dtype::dtype_timestamp64; + if (not timestamp_is_int96) { + col_schema.converted_type = ConvertedType::TIMESTAMP_MILLIS; + col_schema.logical_type = LogicalType{TimestampType{false, TimeUnit::MILLIS}}; + } } template std::enable_if_t, void> operator()() { - col_schema.type = (timestamp_is_int96) ? Type::INT96 : Type::INT64; - col_schema.converted_type = - (timestamp_is_int96) ? ConvertedType::UNKNOWN : ConvertedType::TIMESTAMP_MICROS; + col_schema.type = (timestamp_is_int96) ? Type::INT96 : Type::INT64; col_schema.stats_dtype = statistics_dtype::dtype_timestamp64; + if (not timestamp_is_int96) { + col_schema.converted_type = ConvertedType::TIMESTAMP_MICROS; + col_schema.logical_type = LogicalType{TimestampType{false, TimeUnit::MICROS}}; + } } template std::enable_if_t, void> operator()() { col_schema.type = (timestamp_is_int96) ? Type::INT96 : Type::INT64; - col_schema.converted_type = ConvertedType::UNKNOWN; + col_schema.converted_type = thrust::nullopt; col_schema.stats_dtype = statistics_dtype::dtype_timestamp64; if (timestamp_is_int96) { col_schema.ts_scale = -1000; // negative value indicates division by absolute value } // set logical type if it's not int96 else { - col_schema.logical_type.isset.TIMESTAMP = true; - col_schema.logical_type.TIMESTAMP.unit.isset.NANOS = true; + col_schema.logical_type = LogicalType{TimestampType{false, TimeUnit::NANOS}}; } } @@ -431,53 +449,48 @@ struct leaf_schema_fn { template std::enable_if_t, void> operator()() { - col_schema.type = Type::INT32; - col_schema.converted_type = ConvertedType::TIME_MILLIS; - col_schema.stats_dtype = statistics_dtype::dtype_int32; - col_schema.ts_scale = 24 * 60 * 60 * 1000; - col_schema.logical_type.isset.TIME = true; - col_schema.logical_type.TIME.unit.isset.MILLIS = true; + col_schema.type = Type::INT32; + col_schema.converted_type = ConvertedType::TIME_MILLIS; + col_schema.stats_dtype = statistics_dtype::dtype_int32; + col_schema.ts_scale = 24 * 60 * 60 * 1000; + col_schema.logical_type = LogicalType{TimeType{false, TimeUnit::MILLIS}}; } template std::enable_if_t, void> operator()() { - col_schema.type = Type::INT32; - col_schema.converted_type = ConvertedType::TIME_MILLIS; - col_schema.stats_dtype = statistics_dtype::dtype_int32; - col_schema.ts_scale = 1000; - col_schema.logical_type.isset.TIME = true; - col_schema.logical_type.TIME.unit.isset.MILLIS = true; + col_schema.type = Type::INT32; + col_schema.converted_type = ConvertedType::TIME_MILLIS; + col_schema.stats_dtype = statistics_dtype::dtype_int32; + col_schema.ts_scale = 1000; + col_schema.logical_type = LogicalType{TimeType{false, TimeUnit::MILLIS}}; } template std::enable_if_t, void> operator()() { - col_schema.type = Type::INT32; - col_schema.converted_type = ConvertedType::TIME_MILLIS; - col_schema.stats_dtype = statistics_dtype::dtype_int32; - col_schema.logical_type.isset.TIME = true; - col_schema.logical_type.TIME.unit.isset.MILLIS = true; + col_schema.type = Type::INT32; + col_schema.converted_type = ConvertedType::TIME_MILLIS; + col_schema.stats_dtype = statistics_dtype::dtype_int32; + col_schema.logical_type = LogicalType{TimeType{false, TimeUnit::MILLIS}}; } template std::enable_if_t, void> operator()() { - col_schema.type = Type::INT64; - col_schema.converted_type = ConvertedType::TIME_MICROS; - col_schema.stats_dtype = statistics_dtype::dtype_int64; - col_schema.logical_type.isset.TIME = true; - col_schema.logical_type.TIME.unit.isset.MICROS = true; + col_schema.type = Type::INT64; + col_schema.converted_type = ConvertedType::TIME_MICROS; + col_schema.stats_dtype = statistics_dtype::dtype_int64; + col_schema.logical_type = LogicalType{TimeType{false, TimeUnit::MICROS}}; } // unsupported outside cudf for parquet 1.0. template std::enable_if_t, void> operator()() { - col_schema.type = Type::INT64; - col_schema.stats_dtype = statistics_dtype::dtype_int64; - col_schema.logical_type.isset.TIME = true; - col_schema.logical_type.TIME.unit.isset.NANOS = true; + col_schema.type = Type::INT64; + col_schema.stats_dtype = statistics_dtype::dtype_int64; + col_schema.logical_type = LogicalType{TimeType{false, TimeUnit::NANOS}}; } template @@ -487,27 +500,32 @@ struct leaf_schema_fn { col_schema.type = Type::INT32; col_schema.stats_dtype = statistics_dtype::dtype_int32; col_schema.decimal_precision = MAX_DECIMAL32_PRECISION; + col_schema.logical_type = LogicalType{DecimalType{0, MAX_DECIMAL32_PRECISION}}; } else if (std::is_same_v) { col_schema.type = Type::INT64; col_schema.stats_dtype = statistics_dtype::dtype_decimal64; col_schema.decimal_precision = MAX_DECIMAL64_PRECISION; + col_schema.logical_type = LogicalType{DecimalType{0, MAX_DECIMAL64_PRECISION}}; } else if (std::is_same_v) { col_schema.type = Type::FIXED_LEN_BYTE_ARRAY; col_schema.type_length = sizeof(__int128_t); col_schema.stats_dtype = statistics_dtype::dtype_decimal128; col_schema.decimal_precision = MAX_DECIMAL128_PRECISION; + col_schema.logical_type = LogicalType{DecimalType{0, MAX_DECIMAL128_PRECISION}}; } else { CUDF_FAIL("Unsupported fixed point type for parquet writer"); } col_schema.converted_type = ConvertedType::DECIMAL; col_schema.decimal_scale = -col->type().scale(); // parquet and cudf disagree about scale signs + col_schema.logical_type->decimal_type->scale = -col->type().scale(); if (col_meta.is_decimal_precision_set()) { CUDF_EXPECTS(col_meta.get_decimal_precision() >= col_schema.decimal_scale, "Precision must be equal to or greater than scale!"); if (col_schema.type == Type::INT64 and col_meta.get_decimal_precision() < 10) { CUDF_LOG_WARN("Parquet writer: writing a decimal column with precision < 10 as int64"); } - col_schema.decimal_precision = col_meta.get_decimal_precision(); + col_schema.decimal_precision = col_meta.get_decimal_precision(); + col_schema.logical_type->decimal_type->precision = col_meta.get_decimal_precision(); } } @@ -593,7 +611,7 @@ std::vector construct_schema_tree( schema_tree_node col_schema{}; col_schema.type = Type::BYTE_ARRAY; - col_schema.converted_type = ConvertedType::UNKNOWN; + col_schema.converted_type = thrust::nullopt; col_schema.stats_dtype = statistics_dtype::dtype_byte_array; col_schema.repetition_type = col_nullable ? OPTIONAL : REQUIRED; col_schema.name = (schema[parent_idx].name == "list") ? "element" : col_meta.get_name(); @@ -762,7 +780,10 @@ struct parquet_column_view { [[nodiscard]] column_view cudf_column_view() const { return cudf_col; } [[nodiscard]] Type physical_type() const { return schema_node.type; } - [[nodiscard]] ConvertedType converted_type() const { return schema_node.converted_type; } + [[nodiscard]] ConvertedType converted_type() const + { + return schema_node.converted_type.value_or(UNKNOWN); + } std::vector const& get_path_in_schema() { return path_in_schema; } diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index 2a654bd7e8c..fece83f891b 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -4075,11 +4075,12 @@ int32_t compare(T& v1, T& v2) int32_t compare_binary(std::vector const& v1, std::vector const& v2, cudf::io::parquet::detail::Type ptype, - cudf::io::parquet::detail::ConvertedType ctype) + thrust::optional const& ctype) { + auto ctype_val = ctype.value_or(cudf::io::parquet::detail::UNKNOWN); switch (ptype) { case cudf::io::parquet::detail::INT32: - switch (ctype) { + switch (ctype_val) { case cudf::io::parquet::detail::UINT_8: case cudf::io::parquet::detail::UINT_16: case cudf::io::parquet::detail::UINT_32: @@ -4091,7 +4092,7 @@ int32_t compare_binary(std::vector const& v1, } case cudf::io::parquet::detail::INT64: - if (ctype == cudf::io::parquet::detail::UINT_64) { + if (ctype_val == cudf::io::parquet::detail::UINT_64) { return compare(*(reinterpret_cast(v1.data())), *(reinterpret_cast(v2.data()))); }