Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor io::orc::ProtobufWriter #12877

Merged
merged 16 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions cpp/src/io/orc/orc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace cudf {
namespace io {
namespace orc {

uint32_t ProtobufReader::read_field_size(const uint8_t* end)
uint32_t ProtobufReader::read_field_size(uint8_t const* end)
{
auto const size = get<uint32_t>();
CUDF_EXPECTS(size <= static_cast<uint32_t>(end - m_cur), "Protobuf parsing out of bounds");
Expand Down Expand Up @@ -213,8 +213,7 @@ void ProtobufWriter::put_row_index_entry(int32_t present_blk,
TypeKind kind,
ColStatsBlob const* stats)
{
std::vector<uint8_t> positions_data;
ProtobufWriter position_writer(&positions_data);
ProtobufWriter position_writer;
auto const positions_size_offset = position_writer.put_uint(
encode_field_number(1, ProtofType::FIXEDLEN)); // 1:positions[packed=true]
position_writer.put_byte(0xcd); // positions size placeholder
Expand Down Expand Up @@ -246,19 +245,20 @@ void ProtobufWriter::put_row_index_entry(int32_t present_blk,
positions_size += position_writer.put_byte(0);
}
}

// size of the field 1
positions_data[positions_size_offset] = static_cast<uint8_t>(positions_size);
position_writer.buffer()[positions_size_offset] = static_cast<uint8_t>(positions_size);

auto const stats_size = (stats == nullptr)
? 0
: varint_size(encode_field_number<decltype(*stats)>(2)) +
varint_size(stats->size()) + stats->size();
auto const entry_size = positions_data.size() + stats_size;
auto const entry_size = position_writer.size() + stats_size;

// 1:RowIndex.entry
put_uint(encode_field_number(1, ProtofType::FIXEDLEN));
put_uint(entry_size);
put_bytes<uint8_t>(positions_data);
put_bytes<uint8_t>(position_writer.buffer());

if (stats != nullptr) {
put_uint(encode_field_number<decltype(*stats)>(2)); // 2: statistics
Expand All @@ -268,7 +268,7 @@ void ProtobufWriter::put_row_index_entry(int32_t present_blk,
}
}

size_t ProtobufWriter::write(const PostScript& s)
size_t ProtobufWriter::write(PostScript const& s)
{
ProtobufFieldWriter w(this);
w.field_uint(1, s.footerLength);
Expand All @@ -280,7 +280,7 @@ size_t ProtobufWriter::write(const PostScript& s)
return w.value();
}

size_t ProtobufWriter::write(const FileFooter& s)
size_t ProtobufWriter::write(FileFooter const& s)
{
ProtobufFieldWriter w(this);
w.field_uint(1, s.headerLength);
Expand All @@ -294,7 +294,7 @@ size_t ProtobufWriter::write(const FileFooter& s)
return w.value();
}

size_t ProtobufWriter::write(const StripeInformation& s)
size_t ProtobufWriter::write(StripeInformation const& s)
{
ProtobufFieldWriter w(this);
w.field_uint(1, s.offset);
Expand All @@ -305,7 +305,7 @@ size_t ProtobufWriter::write(const StripeInformation& s)
return w.value();
}

size_t ProtobufWriter::write(const SchemaType& s)
size_t ProtobufWriter::write(SchemaType const& s)
{
ProtobufFieldWriter w(this);
w.field_uint(1, s.kind);
Expand All @@ -317,15 +317,15 @@ size_t ProtobufWriter::write(const SchemaType& s)
return w.value();
}

size_t ProtobufWriter::write(const UserMetadataItem& s)
size_t ProtobufWriter::write(UserMetadataItem const& s)
{
ProtobufFieldWriter w(this);
w.field_blob(1, s.name);
w.field_blob(2, s.value);
return w.value();
}

size_t ProtobufWriter::write(const StripeFooter& s)
size_t ProtobufWriter::write(StripeFooter const& s)
{
ProtobufFieldWriter w(this);
w.field_repeated_struct(1, s.streams);
Expand All @@ -334,7 +334,7 @@ size_t ProtobufWriter::write(const StripeFooter& s)
return w.value();
}

size_t ProtobufWriter::write(const Stream& s)
size_t ProtobufWriter::write(Stream const& s)
{
ProtobufFieldWriter w(this);
w.field_uint(1, s.kind);
Expand All @@ -343,22 +343,22 @@ size_t ProtobufWriter::write(const Stream& s)
return w.value();
}

size_t ProtobufWriter::write(const ColumnEncoding& s)
size_t ProtobufWriter::write(ColumnEncoding const& s)
{
ProtobufFieldWriter w(this);
w.field_uint(1, s.kind);
if (s.kind == DICTIONARY || s.kind == DICTIONARY_V2) { w.field_uint(2, s.dictionarySize); }
return w.value();
}

size_t ProtobufWriter::write(const StripeStatistics& s)
size_t ProtobufWriter::write(StripeStatistics const& s)
{
ProtobufFieldWriter w(this);
w.field_repeated_struct_blob(1, s.colStats);
return w.value();
}

size_t ProtobufWriter::write(const Metadata& s)
size_t ProtobufWriter::write(Metadata const& s)
{
ProtobufFieldWriter w(this);
w.field_repeated_struct(1, s.stripeStats);
Expand Down Expand Up @@ -443,13 +443,13 @@ host_span<uint8_t const> OrcDecompressor::decompress_blocks(host_span<uint8_t co

metadata::metadata(datasource* const src, rmm::cuda_stream_view stream) : source(src)
{
const auto len = source->size();
const auto max_ps_size = std::min(len, static_cast<size_t>(256));
auto const len = source->size();
auto const max_ps_size = std::min(len, static_cast<size_t>(256));

// Read uncompressed postscript section (max 255 bytes + 1 byte for length)
auto buffer = source->host_read(len - max_ps_size, max_ps_size);
const size_t ps_length = buffer->data()[max_ps_size - 1];
const uint8_t* ps_data = &buffer->data()[max_ps_size - ps_length - 1];
size_t const ps_length = buffer->data()[max_ps_size - 1];
uint8_t const* ps_data = &buffer->data()[max_ps_size - ps_length - 1];
ProtobufReader(ps_data, ps_length).read(ps);
CUDF_EXPECTS(ps.footerLength + ps_length < len, "Invalid footer length");

Expand Down
81 changes: 46 additions & 35 deletions cpp/src/io/orc/orc.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -196,7 +196,7 @@ int constexpr encode_field_number(int field_number) noexcept
*/
class ProtobufReader {
public:
ProtobufReader(const uint8_t* base, size_t len) : m_base(base), m_cur(base), m_end(base + len) {}
ProtobufReader(uint8_t const* base, size_t len) : m_base(base), m_cur(base), m_end(base + len) {}

template <typename T>
void read(T& s)
Expand Down Expand Up @@ -241,40 +241,40 @@ class ProtobufReader {
template <typename T, typename... Operator>
void function_builder(T& s, size_t maxlen, std::tuple<Operator...>& op);

uint32_t read_field_size(const uint8_t* end);
uint32_t read_field_size(uint8_t const* end);

template <typename T, std::enable_if_t<std::is_integral_v<T>>* = nullptr>
void read_field(T& value, const uint8_t* end)
void read_field(T& value, uint8_t const* end)
{
value = get<T>();
}

template <typename T, std::enable_if_t<std::is_enum_v<T>>* = nullptr>
void read_field(T& value, const uint8_t* end)
void read_field(T& value, uint8_t const* end)
{
value = static_cast<T>(get<uint32_t>());
}

template <typename T, std::enable_if_t<std::is_same_v<T, std::string>>* = nullptr>
void read_field(T& value, const uint8_t* end)
void read_field(T& value, uint8_t const* end)
{
auto const size = read_field_size(end);
value.assign(reinterpret_cast<const char*>(m_cur), size);
value.assign(reinterpret_cast<char const*>(m_cur), size);
m_cur += size;
}

template <typename T, std::enable_if_t<std::is_same_v<T, std::vector<std::string>>>* = nullptr>
void read_field(T& value, const uint8_t* end)
void read_field(T& value, uint8_t const* end)
{
auto const size = read_field_size(end);
value.emplace_back(reinterpret_cast<const char*>(m_cur), size);
value.emplace_back(reinterpret_cast<char const*>(m_cur), size);
m_cur += size;
}

template <typename T,
std::enable_if_t<std::is_same_v<T, std::vector<typename T::value_type>> and
!std::is_same_v<std::string, typename T::value_type>>* = nullptr>
void read_field(T& value, const uint8_t* end)
void read_field(T& value, uint8_t const* end)
{
auto const size = read_field_size(end);
value.emplace_back();
Expand All @@ -283,29 +283,29 @@ class ProtobufReader {

template <typename T,
std::enable_if_t<std::is_same_v<T, std::optional<typename T::value_type>>>* = nullptr>
void read_field(T& value, const uint8_t* end)
void read_field(T& value, uint8_t const* end)
{
typename T::value_type contained_value;
read_field(contained_value, end);
value = std::optional<typename T::value_type>{std::move(contained_value)};
}

template <typename T>
auto read_field(T& value, const uint8_t* end) -> decltype(read(value, 0))
auto read_field(T& value, uint8_t const* end) -> decltype(read(value, 0))
{
auto const size = read_field_size(end);
read(value, size);
}

template <typename T, std::enable_if_t<std::is_floating_point_v<T>>* = nullptr>
void read_field(T& value, const uint8_t* end)
void read_field(T& value, uint8_t const* end)
{
memcpy(&value, m_cur, sizeof(T));
m_cur += sizeof(T);
}

template <typename T>
void read_packed_field(T& value, const uint8_t* end)
void read_packed_field(T& value, uint8_t const* end)
{
auto const len = get<uint32_t>();
auto const field_end = std::min(m_cur + len, end);
Expand All @@ -314,7 +314,7 @@ class ProtobufReader {
}

template <typename T>
void read_raw_field(T& value, const uint8_t* end)
void read_raw_field(T& value, uint8_t const* end)
{
auto const size = read_field_size(end);
value.emplace_back(m_cur, m_cur + size);
Expand All @@ -331,7 +331,7 @@ class ProtobufReader {
{
}

inline void operator()(ProtobufReader* pbr, const uint8_t* end)
inline void operator()(ProtobufReader* pbr, uint8_t const* end)
{
pbr->read_field(output_value, end);
}
Expand All @@ -347,7 +347,7 @@ class ProtobufReader {
{
}

inline void operator()(ProtobufReader* pbr, const uint8_t* end)
inline void operator()(ProtobufReader* pbr, uint8_t const* end)
{
pbr->read_packed_field(output_value, end);
}
Expand All @@ -363,15 +363,15 @@ class ProtobufReader {
{
}

inline void operator()(ProtobufReader* pbr, const uint8_t* end)
inline void operator()(ProtobufReader* pbr, uint8_t const* end)
{
pbr->read_raw_field(output_value, end);
}
};

const uint8_t* const m_base;
const uint8_t* m_cur;
const uint8_t* const m_end;
uint8_t const* const m_base;
uint8_t const* m_cur;
uint8_t const* const m_end;

public:
/**
Expand Down Expand Up @@ -477,13 +477,14 @@ inline int64_t ProtobufReader::get<int64_t>()
*/
class ProtobufWriter {
public:
ProtobufWriter() { m_buf = nullptr; }
ProtobufWriter(std::vector<uint8_t>* output) { m_buf = output; }
ProtobufWriter() : m_buf{std::make_unique<std::vector<uint8_t>>()} {}

uint32_t put_byte(uint8_t v)
{
m_buf->push_back(v);
return 1;
}

template <typename T>
uint32_t put_bytes(host_span<T const> values)
{
Expand All @@ -492,6 +493,7 @@ class ProtobufWriter {
m_buf->insert(m_buf->end(), values.begin(), values.end());
return values.size();
}

uint32_t put_uint(uint64_t v)
{
int l = 1;
Expand Down Expand Up @@ -519,6 +521,7 @@ class ProtobufWriter {
int64_t s = (v < 0);
return put_uint(((v ^ -s) << 1) + s);
}

void put_row_index_entry(int32_t present_blk,
int32_t present_ofs,
int32_t data_blk,
Expand All @@ -528,20 +531,28 @@ class ProtobufWriter {
TypeKind kind,
ColStatsBlob const* stats);

void resize(std::size_t bytes) { m_buf->resize(bytes); }

std::size_t size() const { return m_buf->size(); }
uint8_t const* data() { return m_buf->data(); }
vuule marked this conversation as resolved.
Show resolved Hide resolved

std::vector<uint8_t>& buffer() { return *m_buf; }
std::unique_ptr<std::vector<uint8_t>> release() { return std::move(m_buf); }
ttnghia marked this conversation as resolved.
Show resolved Hide resolved

public:
size_t write(const PostScript&);
size_t write(const FileFooter&);
size_t write(const StripeInformation&);
size_t write(const SchemaType&);
size_t write(const UserMetadataItem&);
size_t write(const StripeFooter&);
size_t write(const Stream&);
size_t write(const ColumnEncoding&);
size_t write(const StripeStatistics&);
size_t write(const Metadata&);
size_t write(PostScript const&);
size_t write(FileFooter const&);
size_t write(StripeInformation const&);
size_t write(SchemaType const&);
size_t write(UserMetadataItem const&);
size_t write(StripeFooter const&);
size_t write(Stream const&);
size_t write(ColumnEncoding const&);
size_t write(StripeStatistics const&);
size_t write(Metadata const&);

protected:
std::vector<uint8_t>* m_buf;
std::unique_ptr<std::vector<uint8_t>> m_buf;
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
struct ProtobufFieldWriter;
};

Expand Down Expand Up @@ -613,7 +624,7 @@ struct column_validity_info {
* convenience methods for initializing and accessing metadata.
*/
class metadata {
using OrcStripeInfo = std::pair<const StripeInformation*, const StripeFooter*>;
using OrcStripeInfo = std::pair<StripeInformation const*, StripeFooter const*>;

public:
struct stripe_source_mapping {
Expand Down
Loading