Skip to content

Commit

Permalink
Convert FieldType to scoped enum (#14642)
Browse files Browse the repository at this point in the history
Switch to scoped enum (`enum class`); they are better because, well, values now have a scope.
Another benefit in this case - values are now named consistently with compact protocol.
De-duplicated some code, now that more static_casts are required and duplication stands out more.

Authors:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - MithunR (https://github.com/mythrocks)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #14642
  • Loading branch information
vuule authored Jan 9, 2024
1 parent 856c985 commit 433bdc3
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 123 deletions.
110 changes: 61 additions & 49 deletions cpp/src/io/parquet/compact_protocol_reader.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -45,28 +45,37 @@ class parquet_field {
std::string field_type_string(FieldType type)
{
switch (type) {
case ST_FLD_TRUE: return "bool(true)";
case ST_FLD_FALSE: return "bool(false)";
case ST_FLD_BYTE: return "int8";
case ST_FLD_I16: return "int16";
case ST_FLD_I32: return "int32";
case ST_FLD_I64: return "int64";
case ST_FLD_DOUBLE: return "double";
case ST_FLD_BINARY: return "binary";
case ST_FLD_STRUCT: return "struct";
case ST_FLD_LIST: return "list";
case ST_FLD_SET: return "set";
default: return "unknown(" + std::to_string(type) + ")";
case FieldType::BOOLEAN_TRUE: return "bool(true)";
case FieldType::BOOLEAN_FALSE: return "bool(false)";
case FieldType::I8: return "int8";
case FieldType::I16: return "int16";
case FieldType::I32: return "int32";
case FieldType::I64: return "int64";
case FieldType::DOUBLE: return "double";
case FieldType::BINARY: return "binary";
case FieldType::LIST: return "list";
case FieldType::SET: return "set";
case FieldType::MAP: return "map";
case FieldType::STRUCT: return "struct";
case FieldType::UUID: return "UUID";
default: return "unknown(" + std::to_string(static_cast<uint8_t>(type)) + ")";
}
}

void assert_field_type(int type, FieldType expected)
{
CUDF_EXPECTS(type == expected,
CUDF_EXPECTS(type == static_cast<int>(expected),
"expected " + field_type_string(expected) + " field, got " +
field_type_string(static_cast<FieldType>(type)) + " field instead");
}

void assert_bool_field_type(int type)
{
auto const field_type = static_cast<FieldType>(type);
CUDF_EXPECTS(field_type == FieldType::BOOLEAN_TRUE || field_type == FieldType::BOOLEAN_FALSE,
"expected bool field, got " + field_type_string(field_type) + " field instead");
}

/**
* @brief Abstract base class for list functors.
*/
Expand All @@ -86,7 +95,7 @@ class parquet_field_list : public parquet_field {
public:
inline void operator()(CompactProtocolReader* cpr, int field_type)
{
assert_field_type(field_type, ST_FLD_LIST);
assert_field_type(field_type, FieldType::LIST);
auto const [t, n] = cpr->get_listh();
assert_field_type(t, EXPECTED_ELEM_TYPE);
val.resize(n);
Expand All @@ -111,8 +120,8 @@ class parquet_field_bool : public parquet_field {

inline void operator()(CompactProtocolReader* cpr, int field_type)
{
CUDF_EXPECTS(field_type == ST_FLD_TRUE || field_type == ST_FLD_FALSE, "expected bool field");
val = field_type == ST_FLD_TRUE;
assert_bool_field_type(field_type);
val = field_type == static_cast<int>(FieldType::BOOLEAN_TRUE);
}
};

Expand All @@ -122,14 +131,13 @@ class parquet_field_bool : public parquet_field {
* @return True if field types mismatch or if the process of reading a
* bool fails
*/
struct parquet_field_bool_list : public parquet_field_list<bool, ST_FLD_TRUE> {
struct parquet_field_bool_list : public parquet_field_list<bool, FieldType::BOOLEAN_TRUE> {
parquet_field_bool_list(int f, std::vector<bool>& v) : parquet_field_list(f, v)
{
auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) {
auto const current_byte = cpr->getb();
CUDF_EXPECTS(current_byte == ST_FLD_TRUE || current_byte == ST_FLD_FALSE,
"expected bool field");
this->val[i] = current_byte == ST_FLD_TRUE;
assert_bool_field_type(current_byte);
this->val[i] = current_byte == static_cast<int>(FieldType::BOOLEAN_TRUE);
};
bind_read_func(read_value);
}
Expand Down Expand Up @@ -162,9 +170,9 @@ class parquet_field_int : public parquet_field {
}
};

using parquet_field_int8 = parquet_field_int<int8_t, ST_FLD_BYTE>;
using parquet_field_int32 = parquet_field_int<int32_t, ST_FLD_I32>;
using parquet_field_int64 = parquet_field_int<int64_t, ST_FLD_I64>;
using parquet_field_int8 = parquet_field_int<int8_t, FieldType::I8>;
using parquet_field_int32 = parquet_field_int<int32_t, FieldType::I32>;
using parquet_field_int64 = parquet_field_int<int64_t, FieldType::I64>;

/**
* @brief Functor to read a vector of integers from CompactProtocolReader
Expand All @@ -183,7 +191,7 @@ struct parquet_field_int_list : public parquet_field_list<T, EXPECTED_TYPE> {
}
};

using parquet_field_int64_list = parquet_field_int_list<int64_t, ST_FLD_I64>;
using parquet_field_int64_list = parquet_field_int_list<int64_t, FieldType::I64>;

/**
* @brief Functor to read a string from CompactProtocolReader
Expand All @@ -199,7 +207,7 @@ class parquet_field_string : public parquet_field {

inline void operator()(CompactProtocolReader* cpr, int field_type)
{
assert_field_type(field_type, ST_FLD_BINARY);
assert_field_type(field_type, FieldType::BINARY);
auto const n = cpr->get_u32();
CUDF_EXPECTS(n < static_cast<size_t>(cpr->m_end - cpr->m_cur), "string length mismatch");

Expand All @@ -214,7 +222,7 @@ class parquet_field_string : public parquet_field {
* @return True if field types mismatch or if the process of reading a
* string fails
*/
struct parquet_field_string_list : public parquet_field_list<std::string, ST_FLD_BINARY> {
struct parquet_field_string_list : public parquet_field_list<std::string, FieldType::BINARY> {
parquet_field_string_list(int f, std::vector<std::string>& v) : parquet_field_list(f, v)
{
auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) {
Expand All @@ -241,7 +249,7 @@ class parquet_field_enum : public parquet_field {
parquet_field_enum(int f, Enum& v) : parquet_field(f), val(v) {}
inline void operator()(CompactProtocolReader* cpr, int field_type)
{
assert_field_type(field_type, ST_FLD_I32);
assert_field_type(field_type, FieldType::I32);
val = static_cast<Enum>(cpr->get_i32());
}
};
Expand All @@ -253,8 +261,9 @@ class parquet_field_enum : public parquet_field {
* enum fails
*/
template <typename Enum>
struct parquet_field_enum_list : public parquet_field_list<Enum, ST_FLD_I32> {
parquet_field_enum_list(int f, std::vector<Enum>& v) : parquet_field_list<Enum, ST_FLD_I32>(f, v)
struct parquet_field_enum_list : public parquet_field_list<Enum, FieldType::I32> {
parquet_field_enum_list(int f, std::vector<Enum>& v)
: parquet_field_list<Enum, FieldType::I32>(f, v)
{
auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) {
this->val[i] = static_cast<Enum>(cpr->get_i32());
Expand All @@ -278,7 +287,7 @@ class parquet_field_struct : public parquet_field {

inline void operator()(CompactProtocolReader* cpr, int field_type)
{
assert_field_type(field_type, ST_FLD_STRUCT);
assert_field_type(field_type, FieldType::STRUCT);
cpr->read(&val);
}
};
Expand Down Expand Up @@ -324,7 +333,7 @@ class parquet_field_union_enumerator : public parquet_field {

inline void operator()(CompactProtocolReader* cpr, int field_type)
{
assert_field_type(field_type, ST_FLD_STRUCT);
assert_field_type(field_type, FieldType::STRUCT);
cpr->skip_struct_field(field_type);
val = static_cast<E>(field());
}
Expand All @@ -337,8 +346,9 @@ class parquet_field_union_enumerator : public parquet_field {
* struct fails
*/
template <typename T>
struct parquet_field_struct_list : public parquet_field_list<T, ST_FLD_STRUCT> {
parquet_field_struct_list(int f, std::vector<T>& v) : parquet_field_list<T, ST_FLD_STRUCT>(f, v)
struct parquet_field_struct_list : public parquet_field_list<T, FieldType::STRUCT> {
parquet_field_struct_list(int f, std::vector<T>& v)
: parquet_field_list<T, FieldType::STRUCT>(f, v)
{
auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) {
cpr->read(&this->val[i]);
Expand All @@ -361,7 +371,7 @@ class parquet_field_binary : public parquet_field {

inline void operator()(CompactProtocolReader* cpr, int field_type)
{
assert_field_type(field_type, ST_FLD_BINARY);
assert_field_type(field_type, FieldType::BINARY);
auto const n = cpr->get_u32();
CUDF_EXPECTS(n <= static_cast<size_t>(cpr->m_end - cpr->m_cur), "binary length mismatch");

Expand All @@ -377,7 +387,8 @@ class parquet_field_binary : public parquet_field {
* @return True if field types mismatch or if the process of reading a
* binary fails
*/
struct parquet_field_binary_list : public parquet_field_list<std::vector<uint8_t>, ST_FLD_BINARY> {
struct parquet_field_binary_list
: public parquet_field_list<std::vector<uint8_t>, FieldType::BINARY> {
parquet_field_binary_list(int f, std::vector<std::vector<uint8_t>>& v) : parquet_field_list(f, v)
{
auto const read_value = [this](uint32_t i, CompactProtocolReader* cpr) {
Expand All @@ -404,7 +415,7 @@ class parquet_field_struct_blob : public parquet_field {
parquet_field_struct_blob(int f, std::vector<uint8_t>& v) : parquet_field(f), val(v) {}
inline void operator()(CompactProtocolReader* cpr, int field_type)
{
assert_field_type(field_type, ST_FLD_STRUCT);
assert_field_type(field_type, FieldType::STRUCT);
uint8_t const* const start = cpr->m_cur;
cpr->skip_struct_field(field_type);
if (cpr->m_cur > start) { val.assign(start, cpr->m_cur - 1); }
Expand Down Expand Up @@ -439,24 +450,25 @@ class parquet_field_optional : public parquet_field {
*/
void CompactProtocolReader::skip_struct_field(int t, int depth)
{
switch (t) {
case ST_FLD_TRUE:
case ST_FLD_FALSE: break;
case ST_FLD_I16:
case ST_FLD_I32:
case ST_FLD_I64: get_u64(); break;
case ST_FLD_BYTE: skip_bytes(1); break;
case ST_FLD_DOUBLE: skip_bytes(8); break;
case ST_FLD_BINARY: skip_bytes(get_u32()); break;
case ST_FLD_LIST: [[fallthrough]];
case ST_FLD_SET: {
auto const t_enum = static_cast<FieldType>(t);
switch (t_enum) {
case FieldType::BOOLEAN_TRUE:
case FieldType::BOOLEAN_FALSE: break;
case FieldType::I16:
case FieldType::I32:
case FieldType::I64: get_u64(); break;
case FieldType::I8: skip_bytes(1); break;
case FieldType::DOUBLE: skip_bytes(8); break;
case FieldType::BINARY: skip_bytes(get_u32()); break;
case FieldType::LIST:
case FieldType::SET: {
auto const [t, n] = get_listh();
CUDF_EXPECTS(depth <= 10, "struct nesting too deep");
for (uint32_t i = 0; i < n; i++) {
skip_struct_field(t, depth + 1);
}
} break;
case ST_FLD_STRUCT:
case FieldType::STRUCT:
for (;;) {
int const c = getb();
t = c & 0xf;
Expand Down
43 changes: 22 additions & 21 deletions cpp/src/io/parquet/compact_protocol_writer.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -268,39 +268,40 @@ uint32_t CompactProtocolFieldWriter::put_int(int64_t v)
return put_uint(((v ^ -s) << 1) + s);
}

void CompactProtocolFieldWriter::put_field_header(int f, int cur, int t)
void CompactProtocolFieldWriter::put_field_header(int f, int cur, FieldType t)
{
if (f > cur && f <= cur + 15)
put_byte(((f - cur) << 4) | t);
put_packed_type_byte(f - cur, t);
else {
put_byte(t);
put_byte(static_cast<uint8_t>(t));
put_int(f);
}
}

inline void CompactProtocolFieldWriter::field_bool(int field, bool b)
{
put_field_header(field, current_field_value, b ? ST_FLD_TRUE : ST_FLD_FALSE);
put_field_header(
field, current_field_value, b ? FieldType::BOOLEAN_TRUE : FieldType::BOOLEAN_FALSE);
current_field_value = field;
}

inline void CompactProtocolFieldWriter::field_int8(int field, int8_t val)
{
put_field_header(field, current_field_value, ST_FLD_BYTE);
put_field_header(field, current_field_value, FieldType::I8);
put_byte(val);
current_field_value = field;
}

inline void CompactProtocolFieldWriter::field_int(int field, int32_t val)
{
put_field_header(field, current_field_value, ST_FLD_I32);
put_field_header(field, current_field_value, FieldType::I32);
put_int(val);
current_field_value = field;
}

inline void CompactProtocolFieldWriter::field_int(int field, int64_t val)
{
put_field_header(field, current_field_value, ST_FLD_I64);
put_field_header(field, current_field_value, FieldType::I64);
put_int(val);
current_field_value = field;
}
Expand All @@ -309,8 +310,8 @@ template <>
inline void CompactProtocolFieldWriter::field_int_list<int64_t>(int field,
std::vector<int64_t> const& val)
{
put_field_header(field, current_field_value, ST_FLD_LIST);
put_byte(static_cast<uint8_t>((std::min(val.size(), 0xfUL) << 4) | ST_FLD_I64));
put_field_header(field, current_field_value, FieldType::LIST);
put_packed_type_byte(val.size(), FieldType::I64);
if (val.size() >= 0xfUL) { put_uint(val.size()); }
for (auto const v : val) {
put_int(v);
Expand All @@ -321,8 +322,8 @@ inline void CompactProtocolFieldWriter::field_int_list<int64_t>(int field,
template <typename Enum>
inline void CompactProtocolFieldWriter::field_int_list(int field, std::vector<Enum> const& val)
{
put_field_header(field, current_field_value, ST_FLD_LIST);
put_byte(static_cast<uint8_t>((std::min(val.size(), 0xfUL) << 4) | ST_FLD_I32));
put_field_header(field, current_field_value, FieldType::LIST);
put_packed_type_byte(val.size(), FieldType::I32);
if (val.size() >= 0xfUL) { put_uint(val.size()); }
for (auto const& v : val) {
put_int(static_cast<int32_t>(v));
Expand All @@ -333,7 +334,7 @@ inline void CompactProtocolFieldWriter::field_int_list(int field, std::vector<En
template <typename T>
inline void CompactProtocolFieldWriter::field_struct(int field, T const& val)
{
put_field_header(field, current_field_value, ST_FLD_STRUCT);
put_field_header(field, current_field_value, FieldType::STRUCT);
if constexpr (not std::is_empty_v<T>) {
writer.write(val); // write the struct if it's not empty
} else {
Expand All @@ -344,16 +345,16 @@ inline void CompactProtocolFieldWriter::field_struct(int field, T const& val)

inline void CompactProtocolFieldWriter::field_empty_struct(int field)
{
put_field_header(field, current_field_value, ST_FLD_STRUCT);
put_field_header(field, current_field_value, FieldType::STRUCT);
put_byte(0); // add a stop field
current_field_value = field;
}

template <typename T>
inline void CompactProtocolFieldWriter::field_struct_list(int field, std::vector<T> const& val)
{
put_field_header(field, current_field_value, ST_FLD_LIST);
put_byte((uint8_t)((std::min(val.size(), (size_t)0xfu) << 4) | ST_FLD_STRUCT));
put_field_header(field, current_field_value, FieldType::LIST);
put_packed_type_byte(val.size(), FieldType::STRUCT);
if (val.size() >= 0xf) put_uint(val.size());
for (auto& v : val) {
writer.write(v);
Expand All @@ -370,23 +371,23 @@ inline size_t CompactProtocolFieldWriter::value()
inline void CompactProtocolFieldWriter::field_struct_blob(int field,
std::vector<uint8_t> const& val)
{
put_field_header(field, current_field_value, ST_FLD_STRUCT);
put_field_header(field, current_field_value, FieldType::STRUCT);
put_byte(val.data(), static_cast<uint32_t>(val.size()));
put_byte(0);
current_field_value = field;
}

inline void CompactProtocolFieldWriter::field_binary(int field, std::vector<uint8_t> const& val)
{
put_field_header(field, current_field_value, ST_FLD_BINARY);
put_field_header(field, current_field_value, FieldType::BINARY);
put_uint(val.size());
put_byte(val.data(), static_cast<uint32_t>(val.size()));
current_field_value = field;
}

inline void CompactProtocolFieldWriter::field_string(int field, std::string const& val)
{
put_field_header(field, current_field_value, ST_FLD_BINARY);
put_field_header(field, current_field_value, FieldType::BINARY);
put_uint(val.size());
// FIXME : replace reinterpret_cast
put_byte(reinterpret_cast<uint8_t const*>(val.data()), static_cast<uint32_t>(val.size()));
Expand All @@ -396,8 +397,8 @@ inline void CompactProtocolFieldWriter::field_string(int field, std::string cons
inline void CompactProtocolFieldWriter::field_string_list(int field,
std::vector<std::string> const& val)
{
put_field_header(field, current_field_value, ST_FLD_LIST);
put_byte((uint8_t)((std::min(val.size(), (size_t)0xfu) << 4) | ST_FLD_BINARY));
put_field_header(field, current_field_value, FieldType::LIST);
put_packed_type_byte(val.size(), FieldType::BINARY);
if (val.size() >= 0xf) put_uint(val.size());
for (auto& v : val) {
put_uint(v.size());
Expand Down
Loading

0 comments on commit 433bdc3

Please sign in to comment.