Skip to content

Commit

Permalink
Add column field ID control in parquet writer (#10504)
Browse files Browse the repository at this point in the history
Closes #10375
Closes #10376

This PR enables column `field_id` control in the parquet writer. When writing a parquet file, users can specify a column's `field_id` via `column_in_metadata.set_parquet_field_id()`. JNI bindings and uni tests are added as well.

Authors:
  - Yunsong Wang (https://github.com/PointKernel)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - Vukasin Milovanovic (https://github.com/vuule)
  - Devavret Makkar (https://github.com/devavret)

URL: #10504
  • Loading branch information
PointKernel authored Apr 15, 2022
1 parent 9e1258d commit d5a982b
Show file tree
Hide file tree
Showing 12 changed files with 368 additions and 27 deletions.
27 changes: 26 additions & 1 deletion cpp/include/cudf/io/types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -244,6 +244,7 @@ class column_in_metadata {
bool _use_int96_timestamp = false;
// bool _output_as_binary = false;
thrust::optional<uint8_t> _decimal_precision;
thrust::optional<int32_t> _parquet_field_id;
std::vector<column_in_metadata> children;

public:
Expand Down Expand Up @@ -324,6 +325,18 @@ class column_in_metadata {
return *this;
}

/**
* @brief Set the parquet field id of this column.
*
* @param field_id The parquet field id to set
* @return this for chaining
*/
column_in_metadata& set_parquet_field_id(int32_t field_id)
{
_parquet_field_id = field_id;
return *this;
}

/**
* @brief Get reference to a child of this column
*
Expand Down Expand Up @@ -379,6 +392,18 @@ class column_in_metadata {
*/
[[nodiscard]] uint8_t get_decimal_precision() const { return _decimal_precision.value(); }

/**
* @brief Get whether parquet field id has been set for this column.
*/
[[nodiscard]] bool is_parquet_field_id_set() const { return _parquet_field_id.has_value(); }

/**
* @brief Get the parquet field id that was set for this column.
* @throws If parquet field id was not set for this column.
* Check using `is_parquet_field_id_set()` first.
*/
[[nodiscard]] int32_t get_parquet_field_id() const { return _parquet_field_id.value(); }

/**
* @brief Get the number of children of this column
*/
Expand Down
1 change: 1 addition & 0 deletions cpp/src/io/parquet/compact_protocol_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ bool CompactProtocolReader::read(SchemaElement* s)
ParquetFieldEnum<ConvertedType>(6, s->converted_type),
ParquetFieldInt32(7, s->decimal_scale),
ParquetFieldInt32(8, s->decimal_precision),
ParquetFieldOptionalInt32(9, s->field_id),
ParquetFieldStruct(10, s->logical_type));
return function_builder(this, op);
}
Expand Down
24 changes: 24 additions & 0 deletions cpp/src/io/parquet/compact_protocol_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

#include "parquet.hpp"

#include <thrust/optional.h>

#include <algorithm>
#include <cstddef>
#include <string>
Expand Down Expand Up @@ -137,6 +139,7 @@ class CompactProtocolReader {
friend class ParquetFieldBool;
friend class ParquetFieldInt8;
friend class ParquetFieldInt32;
friend class ParquetFieldOptionalInt32;
friend class ParquetFieldInt64;
template <typename T>
friend class ParquetFieldStructListFunctor;
Expand Down Expand Up @@ -216,6 +219,27 @@ class ParquetFieldInt32 {
int field() { return field_val; }
};

/**
* @brief Functor to set value to optional 32 bit integer read from CompactProtocolReader
*
* @return True if field type is not int32
*/
class ParquetFieldOptionalInt32 {
int field_val;
thrust::optional<int32_t>& val;

public:
ParquetFieldOptionalInt32(int f, thrust::optional<int32_t>& v) : field_val(f), val(v) {}

inline bool operator()(CompactProtocolReader* cpr, int field_type)
{
val = cpr->get_i32();
return (field_type != ST_FLD_I32);
}

int field() { return field_val; }
};

/**
* @brief Functor to set value to 64 bit integer read from CompactProtocolReader
*
Expand Down
1 change: 1 addition & 0 deletions cpp/src/io/parquet/compact_protocol_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ size_t CompactProtocolWriter::write(const SchemaElement& s)
c.field_int(8, s.decimal_precision);
}
}
if (s.field_id) { c.field_int(9, s.field_id.value()); }
auto const isset = s.logical_type.isset;
// TODO: add handling for all logical types
// if (isset.STRING or isset.MAP or isset.LIST or isset.ENUM or isset.DECIMAL or isset.DATE or
Expand Down
6 changes: 5 additions & 1 deletion cpp/src/io/parquet/parquet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

#include "parquet_common.hpp"

#include <thrust/optional.h>

#include <cstdint>
#include <string>
#include <vector>
Expand Down Expand Up @@ -145,6 +147,7 @@ struct SchemaElement {
int32_t num_children = 0;
int32_t decimal_scale = 0;
int32_t decimal_precision = 0;
thrust::optional<int32_t> field_id = thrust::nullopt;

// The following fields are filled in later during schema initialization
int max_definition_level = 0;
Expand All @@ -157,7 +160,8 @@ struct SchemaElement {
return type == other.type && converted_type == other.converted_type &&
type_length == other.type_length && repetition_type == other.repetition_type &&
name == other.name && num_children == other.num_children &&
decimal_scale == other.decimal_scale && decimal_precision == other.decimal_precision;
decimal_scale == other.decimal_scale && decimal_precision == other.decimal_precision &&
field_id == other.field_id;
}

// the parquet format is a little squishy when it comes to interpreting
Expand Down
15 changes: 14 additions & 1 deletion cpp/src/io/parquet/writer_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,13 @@ std::vector<schema_tree_node> construct_schema_tree(
[&](cudf::detail::LinkedColPtr const& col, column_in_metadata& col_meta, size_t parent_idx) {
bool col_nullable = is_col_nullable(col, col_meta, single_write_mode);

auto set_field_id = [&schema, parent_idx](schema_tree_node& s,
column_in_metadata const& col_meta) {
if (schema[parent_idx].name != "list" and col_meta.is_parquet_field_id_set()) {
s.field_id = col_meta.get_parquet_field_id();
}
};

if (col->type().id() == type_id::STRUCT) {
// if struct, add current and recursively call for all children
schema_tree_node struct_schema{};
Expand All @@ -500,6 +507,7 @@ std::vector<schema_tree_node> construct_schema_tree(
struct_schema.name = (schema[parent_idx].name == "list") ? "element" : col_meta.get_name();
struct_schema.num_children = col->children.size();
struct_schema.parent_idx = parent_idx;
set_field_id(struct_schema, col_meta);
schema.push_back(std::move(struct_schema));

auto struct_node_index = schema.size() - 1;
Expand All @@ -524,6 +532,7 @@ std::vector<schema_tree_node> construct_schema_tree(
list_schema_1.name = (schema[parent_idx].name == "list") ? "element" : col_meta.get_name();
list_schema_1.num_children = 1;
list_schema_1.parent_idx = parent_idx;
set_field_id(list_schema_1, col_meta);
schema.push_back(std::move(list_schema_1));

schema_tree_node list_schema_2{};
Expand Down Expand Up @@ -555,7 +564,10 @@ std::vector<schema_tree_node> construct_schema_tree(
map_schema.converted_type = ConvertedType::MAP;
map_schema.repetition_type =
col_nullable ? FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED;
map_schema.name = col_meta.get_name();
map_schema.name = col_meta.get_name();
if (col_meta.is_parquet_field_id_set()) {
map_schema.field_id = col_meta.get_parquet_field_id();
}
map_schema.num_children = 1;
map_schema.parent_idx = parent_idx;
schema.push_back(std::move(map_schema));
Expand Down Expand Up @@ -612,6 +624,7 @@ std::vector<schema_tree_node> construct_schema_tree(
col_schema.name = (schema[parent_idx].name == "list") ? "element" : col_meta.get_name();
col_schema.parent_idx = parent_idx;
col_schema.leaf_column = col;
set_field_id(col_schema, col_meta);
schema.push_back(col_schema);
}
};
Expand Down
10 changes: 8 additions & 2 deletions cpp/tests/io/parquet_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,21 @@ struct ParquetWriterTimestampTypeTest : public ParquetWriterTest {
auto type() { return cudf::data_type{cudf::type_to_id<T>()}; }
};

// Typed test fixture for all types
template <typename T>
struct ParquetWriterSchemaTest : public ParquetWriterTest {
auto type() { return cudf::data_type{cudf::type_to_id<T>()}; }
};

// Declare typed test cases
// TODO: Replace with `NumericTypes` when unsigned support is added. Issue #5352
using SupportedTypes = cudf::test::Types<int8_t, int16_t, int32_t, int64_t, bool, float, double>;
TYPED_TEST_SUITE(ParquetWriterNumericTypeTest, SupportedTypes);
using SupportedChronoTypes = cudf::test::Concat<cudf::test::ChronoTypes, cudf::test::DurationTypes>;
TYPED_TEST_SUITE(ParquetWriterChronoTypeTest, SupportedChronoTypes);
TYPED_TEST_SUITE(ParquetWriterChronoTypeTest, cudf::test::ChronoTypes);
using SupportedTimestampTypes =
cudf::test::Types<cudf::timestamp_ms, cudf::timestamp_us, cudf::timestamp_ns>;
TYPED_TEST_SUITE(ParquetWriterTimestampTypeTest, SupportedTimestampTypes);
TYPED_TEST_SUITE(ParquetWriterSchemaTest, cudf::test::AllTypes);

// Base test fixture for chunked writer tests
struct ParquetChunkedWriterTest : public cudf::test::BaseFixture {
Expand Down
Loading

0 comments on commit d5a982b

Please sign in to comment.