From 91f7943e61c75dffff4bcff382dd62b2be167978 Mon Sep 17 00:00:00 2001
From: Ed Seidl
Date: Tue, 12 Sep 2023 18:28:58 -0700
Subject: [PATCH 01/37] add DELTA_BINARY_PACKED encoder

---
 cpp/src/io/parquet/delta_binary.cuh |   6 -
 cpp/src/io/parquet/delta_enc.cuh    | 269 ++++++++
 cpp/src/io/parquet/page_enc.cu      | 983 ++++++++++++++++++++--------
 cpp/src/io/parquet/parquet_gpu.hpp  |  18 +
 cpp/tests/io/parquet_test.cpp       |  18 +-
 5 files changed, 999 insertions(+), 295 deletions(-)
 create mode 100644 cpp/src/io/parquet/delta_enc.cuh

diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh
index 4fc8b9cfb8e..7aecc7f01e0 100644
--- a/cpp/src/io/parquet/delta_binary.cuh
+++ b/cpp/src/io/parquet/delta_binary.cuh
@@ -46,12 +46,6 @@ namespace cudf::io::parquet::gpu {
 // encoded with DELTA_LENGTH_BYTE_ARRAY encoding, which is a DELTA_BINARY_PACKED list of suffix
 // lengths, followed by the concatenated suffix data.
 
-// TODO: The delta encodings use ULEB128 integers, but for now we're only
-// using max 64 bits. Need to see what the performance impact is of using
-// __int128_t rather than int64_t.
-using uleb128_t   = uint64_t;
-using zigzag128_t = int64_t;
-
 // we decode one mini-block at a time. max mini-block size seen is 64.
 constexpr int delta_rolling_buf_size = 128;
 
diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh
new file mode 100644
index 00000000000..164849edd63
--- /dev/null
+++ b/cpp/src/io/parquet/delta_enc.cuh
@@ -0,0 +1,269 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "parquet_gpu.hpp"
+
+#include <cudf/detail/utilities/cuda.cuh>
+
+#include <cub/cub.cuh>
+
+namespace cudf::io::parquet::gpu {
+
+namespace delta {
+
+inline __device__ void put_uleb128(uint8_t*& p, uleb128_t v)
+{
+  while (v > 0x7f) {
+    *p++ = v | 0x80;
+    v >>= 7;
+  }
+  *p++ = v;
+}
+
+inline __device__ void put_zz128(uint8_t*& p, zigzag128_t v)
+{
+  zigzag128_t s = (v < 0);
+  put_uleb128(p, (v ^ -s) * 2 + s);
+}
+
+// a block size of 128, with 4 mini-blocks of 32 values each, fits nicely without consuming
+// too much shared memory.
+constexpr int block_size            = 128;
+constexpr int num_mini_blocks       = 4;
+constexpr int values_per_mini_block = block_size / num_mini_blocks;
+constexpr int buffer_size           = 2 * block_size;
+
+using block_reduce = cub::BlockReduce<zigzag128_t, block_size>;
+using warp_reduce  = cub::WarpReduce<uleb128_t>;
+using index_scan   = cub::BlockScan<size_type, block_size>;
+
+constexpr int rolling_idx(int index) { return rolling_index<buffer_size>(index); }
+
+// version of bit packer that can handle up to 64-bit values.
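+//
+// For example, with nbits == 3 each lane i of the warp owns bits [3*i, 3*i + 3) of the packed
+// output: lane 0 fills bits 0-2, lane 1 bits 3-5, and so on. Each lane shifts its value into
+// place in a double-width word, atomically ORs the two halves into a shared scratch buffer,
+// and the warp then copies the scratch out to the destination byte by byte.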
+template <typename T, typename WideType>
+inline __device__ void bitpack_mini_block(
+  uint8_t* dst, T val, uint32_t count, uint8_t nbits, void* temp_space)
+{
+  // typing for atomicOr is annoying
+  using scratch_type =
+    std::conditional_t<std::is_same_v<T, uint64_t>, unsigned long long, uint32_t>;
+  using cudf::detail::warp_size;
+  T constexpr mask   = sizeof(T) * 8 - 1;
+  auto constexpr div = sizeof(T) * 8;
+
+  auto const lane_id = threadIdx.x % warp_size;
+  auto const warp_id = threadIdx.x / warp_size;
+
+  scratch_type* scratch = reinterpret_cast<scratch_type*>(temp_space) + warp_id * warp_size;
+
+  // zero out scratch
+  scratch[lane_id] = 0;
+  __syncwarp();
+
+  // if the bit width equals the type width there are no savings from bit packing,
+  // so just copy the raw bytes
+  if (nbits == div) {
+    if (lane_id < count) {
+      for (int i = 0; i < sizeof(T); i++) {
+        dst[lane_id * sizeof(T) + i] = val & 0xff;
+        if constexpr (sizeof(T) > 1) { val >>= 8; }
+      }
+    }
+    __syncwarp();
+    return;
+  }
+
+  if (lane_id < count) {
+    // shift symbol left by up to mask bits
+    WideType v2 = val;
+    v2 <<= (lane_id * nbits) & mask;
+
+    // Copy N bit word into two N/2 bit words while following C++ strict aliasing rules.
+    T v1[2];
+    memcpy(&v1, &v2, sizeof(WideType));
+
+    // Atomically write result to scratch
+    if (v1[0]) { atomicOr(scratch + ((lane_id * nbits) / div), v1[0]); }
+    if (v1[1]) { atomicOr(scratch + ((lane_id * nbits) / div) + 1, v1[1]); }
+  }
+  __syncwarp();
+
+  // Copy scratch data to final destination
+  auto available_bytes = (count * nbits + 7) / 8;
+
+  auto scratch_bytes = reinterpret_cast<uint8_t const*>(scratch);
+  for (uint32_t i = lane_id; i < available_bytes; i += warp_size) {
+    dst[i] = scratch_bytes[i];
+  }
+  __syncwarp();
+}
+
+}  // namespace delta
+
+// Object used to turn a stream of integers into a DELTA_BINARY_PACKED stream. This takes as input
+// 128 values with validity at a time, saving them until there are enough values for a block
+// to be written.
+//
+// T can only be uint32_t or uint64_t since the DELTA_BINARY_PACKED encoding is only defined for
+// INT32 and INT64 physical types
+template <typename T>
+class DeltaBinaryPacker {
+ private:
+  // static_assert(std::is_same_v<T, uint32_t> || std::is_same_v<T, uint64_t>);
+
+  uint8_t* _dst;                             // sink to dump encoded values to
+  size_type _current_idx;                    // index of first value in buffer
+  uint32_t _num_values;                      // total number of values to encode
+  size_type _values_in_buffer;               // current number of values stored in _buffer
+  T* _buffer;                                // buffer to store values to be encoded
+  uint8_t _mb_bits[delta::num_mini_blocks];  // bitwidth for each mini-block
+
+  // pointers to shared scratch memory for the warp and block scans/reduces
+  delta::index_scan::TempStorage* _scan_tmp;
+  delta::warp_reduce::TempStorage* _warp_tmp;
+  delta::block_reduce::TempStorage* _block_tmp;
+
+  void* _bitpack_tmp;  // pointer to shared scratch memory used in bitpacking
+
+  // write the delta binary header. only call from thread 0
+  inline __device__ void write_header(T first_value)
+  {
+    delta::put_uleb128(_dst, delta::block_size);
+    delta::put_uleb128(_dst, delta::num_mini_blocks);
+    delta::put_uleb128(_dst, _num_values);
+    delta::put_zz128(_dst, first_value);
+  }
+
+  // write the block header. only call from thread 0
+  inline __device__ void write_block_header(zigzag128_t block_min)
+  {
+    delta::put_zz128(_dst, block_min);
+    memcpy(_dst, _mb_bits, 4);
+    _dst += 4;
+  }
+
+ public:
+  inline __device__ auto num_values() const { return _num_values; }
+
+  // initialize the object. only call from thread 0
+  inline __device__ void init(uint8_t* dest, uint32_t num_values, T* buffer, void* temp_storage)
+  {
+    _dst              = dest;
+    _num_values       = num_values;
+    _buffer           = buffer;
+    _scan_tmp         = reinterpret_cast<delta::index_scan::TempStorage*>(temp_storage);
+    _warp_tmp         = reinterpret_cast<delta::warp_reduce::TempStorage*>(temp_storage);
+    _block_tmp        = reinterpret_cast<delta::block_reduce::TempStorage*>(temp_storage);
+    _bitpack_tmp      = _buffer + delta::buffer_size;
+    _current_idx      = 0;
+    _values_in_buffer = 0;
+  }
+
+  // each thread calls this to add its current value
+  inline __device__ void add_value(T value, bool is_valid)
+  {
+    // figure out the correct position for the given value
+    size_type const valid = is_valid;
+    size_type pos;
+    size_type num_valid;
+    delta::index_scan(*_scan_tmp).ExclusiveSum(valid, pos, num_valid);
+
+    if (is_valid) { _buffer[delta::rolling_idx(pos + _current_idx + _values_in_buffer)] = value; }
+    __syncthreads();
+
+    if (threadIdx.x == 0) {
+      _values_in_buffer += num_valid;
+      // if first pass write header
+      if (_current_idx == 0) {
+        write_header(_buffer[0]);
+        _current_idx = 1;
+        _values_in_buffer -= 1;
+      }
+    }
+    __syncthreads();
+
+    if (_values_in_buffer >= delta::block_size) { flush(); }
+  }
+
+  // called by each thread to flush data to the sink.
+  inline __device__ uint8_t const* flush()
+  {
+    using cudf::detail::warp_size;
+    __shared__ zigzag128_t block_min;
+
+    int const t       = threadIdx.x;
+    int const warp_id = t / warp_size;
+    int const lane_id = t % warp_size;
+
+    if (_values_in_buffer <= 0) { return _dst; }
+
+    // calculate delta for this thread
+    size_type const idx = _current_idx + t;
+    zigzag128_t const delta =
+      idx < _num_values ? _buffer[delta::rolling_idx(idx)] - _buffer[delta::rolling_idx(idx - 1)]
+                        : std::numeric_limits<zigzag128_t>::max();
+
+    // find min delta for the block
+    auto const min_delta = delta::block_reduce(*_block_tmp).Reduce(delta, cub::Min());
+
+    if (t == 0) { block_min = min_delta; }
+    __syncthreads();
+
+    // compute frame of reference for the block
+    uleb128_t const norm_delta = idx < _num_values ? delta - block_min : 0;
+
+    // get max normalized delta for each warp, and use that to determine how many bits to use
+    // for the bitpacking of this warp
+    zigzag128_t const warp_max =
+      delta::warp_reduce(_warp_tmp[warp_id]).Reduce(norm_delta, cub::Max());
+
+    if (lane_id == 0) { _mb_bits[warp_id] = sizeof(zigzag128_t) * 8 - __clzll(warp_max); }
+    __syncthreads();
+
+    // write block header
+    if (t == 0) { write_block_header(block_min); }
+    __syncthreads();
+
+    // now each warp encodes its data...can calculate starting offset with _mb_bits
+    uint8_t* mb_ptr = _dst;
+    switch (warp_id) {
+      case 3: mb_ptr += _mb_bits[2] * delta::values_per_mini_block / 8; [[fallthrough]];
+      case 2: mb_ptr += _mb_bits[1] * delta::values_per_mini_block / 8; [[fallthrough]];
+      case 1: mb_ptr += _mb_bits[0] * delta::values_per_mini_block / 8;
    }
+
+    // encoding happens here....will have to update pack literals to deal with larger numbers
+    auto const warp_idx = _current_idx + warp_id * delta::values_per_mini_block;
+    if (warp_idx < _num_values) {
+      auto const num_enc = min(delta::values_per_mini_block, _num_values - warp_idx);
+      delta::bitpack_mini_block<uleb128_t, __uint128_t>(
+        mb_ptr, norm_delta, num_enc, _mb_bits[warp_id], _bitpack_tmp);
+    }
+
+    // last lane updates global delta ptr
+    if (warp_id == delta::num_mini_blocks - 1 && lane_id == 0) {
+      _dst              = mb_ptr + _mb_bits[warp_id] * delta::values_per_mini_block / 8;
+      _current_idx      = min(warp_idx + delta::values_per_mini_block, _num_values);
+      _values_in_buffer = max(_values_in_buffer - delta::block_size, 0U);
+    }
+    __syncthreads();
+
+    return _dst;
+  }
+};
+
+}  // namespace cudf::io::parquet::gpu
diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu
index 0af561be8da..fe212ec6714 100644
--- a/cpp/src/io/parquet/page_enc.cu
+++ b/cpp/src/io/parquet/page_enc.cu
@@ -14,6 +14,7 @@
  * limitations under the License.
 */
 
+#include "delta_enc.cuh"
 #include "parquet_gpu.cuh"
 
 #include 
@@ -21,6 +22,7 @@
 #include 
 #include 
 #include 
+#include <cudf/detail/utilities/stream_pool.hpp>
 
 #include 
 #include 
@@ -41,6 +43,8 @@
 #include 
 #include 
 
+#include <bitset>
+
 namespace cudf {
 namespace io {
 namespace parquet {
@@ -50,7 +54,11 @@ namespace {
 
 using ::cudf::detail::device_2dspan;
 
-constexpr uint32_t rle_buffer_size = (1 << 9);
+constexpr int encode_block_size = 128;
+constexpr int rle_buffer_size   = 2 * encode_block_size;
+constexpr int num_encode_warps  = encode_block_size / cudf::detail::warp_size;
+
+constexpr int rolling_idx(int pos) { return rolling_index<rle_buffer_size>(pos); }
 
 // do not truncate statistics
 constexpr int32_t NO_TRUNC_STATS = 0;
@@ -72,6 +80,7 @@ struct frag_init_state_s {
   parquet_column_device_view col;
   PageFragment frag;
 };
 
+template <int rle_buf_size>
 struct page_enc_state_s {
   uint8_t* cur;      //!< current output ptr
   uint8_t* rle_out;  //!< current RLE write ptr
@@ -84,12 +93,11 @@ struct page_enc_state_s {
   uint32_t rle_rpt_count;
   uint32_t page_start_val;
   uint32_t chunk_start_val;
-  volatile uint32_t rpt_map[4];
-  volatile uint32_t scratch_red[32];
+  volatile uint32_t rpt_map[num_encode_warps];
   EncPage page;
   EncColumnChunk ck;
   parquet_column_device_view col;
-  uint32_t vals[rle_buffer_size];
+  uint32_t vals[rle_buf_size];
 };
 
 /**
@@ -239,6 +247,49 @@ struct BitwiseOr {
   }
 };
 
+// T is the parquet physical type
+// W is double the bitwidth of T
+// I is the column type from the input table
+// F is a function that computes validity and the src index for a given input position
+template <typename T, typename W, typename I, typename F>
+struct delta_enc {
+  page_enc_state_s<0>* s;
+  uint32_t valid_count;
+  F& f;
+  uint64_t* buffer;
+  void* temp_space;
+
+  __device__ uint8_t const* encode()
+  {
+    __shared__ DeltaBinaryPacker<T> packer;
+
+    auto const t = threadIdx.x;
+
+    if (t == 0) { packer.init(s->cur, valid_count, reinterpret_cast<T*>(buffer), temp_space); }
+    __syncthreads();
+
+    // FIXME in the plain encoder the scaling is a little different for INT32 than INT64.
+    // might need to patch this up some.
+    int32_t const scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale;
+    for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) {
+      uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size);
+
+      auto [is_valid, val_idx] = f(cur_val_idx);
+      cur_val_idx += nvals;
+
+      // only dereference the leaf column for valid positions to avoid reading out of bounds
+      T v = is_valid ? s->col.leaf_column->element<I>(val_idx) : T{0};
+      if (scale < 0) {
+        v /= -scale;
+      } else {
+        v *= scale;
+      }
+      packer.add_value(v, is_valid);
+    }
+
+    return packer.flush();
+  }
+};
+
 } // anonymous namespace
 
 // blockDim {512,1,1}
 __global__ void __launch_bounds__(128)
@@ -326,6 +377,31 @@
   }
 }
 
+__device__ size_t delta_data_len(parquet::Type physical_type,
+                                 cudf::type_id type_id,
+                                 uint32_t num_values)
+{
+  auto const dtype_len_out = physical_type_len(physical_type, type_id);
+  auto const dtype_len     = [&]() -> uint32_t {
+    if (physical_type == INT32) { return int32_logical_len(type_id); }
+    if (physical_type == INT96) { return sizeof(int64_t); }
+    return dtype_len_out;
+  }();
+
+  auto const vals_per_block = delta::block_size;
+  size_t const num_blocks   = util::div_rounding_up_unsafe(num_values, vals_per_block);
+  // need max dtype_len_in + 1 bytes for min_delta
+  // one byte per mini block for the bitwidth
+  // and block_size * dtype_len_in bytes for the actual encoded data
+  auto const block_size = dtype_len + 1 + delta::num_mini_blocks + vals_per_block * dtype_len;
+
+  // delta header is 2 bytes for the block_size, 1 byte for number of mini-blocks,
+  // max 5 bytes for number of values, and max dtype_len_in + 1 for first value.
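+  // e.g. for INT64 data (dtype_len == 8) each complete block of 128 values is bounded by
+  // 8 + 1 + 4 + 128 * 8 = 1037 bytes, and the header by 2 + 1 + 5 + 9 = 17 bytes.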
+ auto const header_size = 2 + 1 + 5 + dtype_len + 1; + + return header_size + num_blocks * block_size; +} + // blockDim {128,1,1} __global__ void __launch_bounds__(128) gpuInitPages(device_2dspan chunks, @@ -357,6 +433,14 @@ __global__ void __launch_bounds__(128) page_g = {}; } __syncthreads(); + + // if writing delta encoded values, we're going to need to know the data length to get a guess + // at the worst case number of bytes needed to encode. + auto const physical_type = col_g.physical_type; + auto const type_id = col_g.leaf_column->type().id(); + auto const is_use_delta = + write_v2_headers && !ck_g.use_dictionary && (physical_type == INT32 || physical_type == INT64); + if (t < 32) { uint32_t fragments_in_chunk = 0; uint32_t rows_in_page = 0; @@ -406,9 +490,12 @@ __global__ void __launch_bounds__(128) } __syncwarp(); if (t == 0) { - if (not pages.empty()) pages[ck_g.first_page] = page_g; - if (not page_sizes.empty()) page_sizes[ck_g.first_page] = page_g.max_data_size; - if (page_grstats) page_grstats[ck_g.first_page] = pagestats_g; + if (not pages.empty()) { + page_g.kernel_mask = ENC_MASK_PLAIN; + pages[ck_g.first_page] = page_g; + } + if (not page_sizes.empty()) { page_sizes[ck_g.first_page] = page_g.max_data_size; } + if (page_grstats) { page_grstats[ck_g.first_page] = pagestats_g; } } num_pages = 1; } @@ -508,7 +595,12 @@ __global__ void __launch_bounds__(128) page_g.num_values = values_in_page; auto const def_level_size = max_RLE_page_size(col_g.num_def_level_bits(), values_in_page); auto const rep_level_size = max_RLE_page_size(col_g.num_rep_level_bits(), values_in_page); - auto const max_data_size = page_size + def_level_size + rep_level_size + rle_pad; + // get a different bound if using delta encoding + if (is_use_delta) { + page_size = + max(page_size, delta_data_len(physical_type, type_id, page_g.num_leaf_values)); + } + auto const max_data_size = page_size + def_level_size + rep_level_size + rle_pad; // page size must fit in 32-bit signed integer if (max_data_size > std::numeric_limits::max()) { CUDF_UNREACHABLE("page size exceeds maximum for i32"); @@ -528,7 +620,16 @@ __global__ void __launch_bounds__(128) } __syncwarp(); if (t == 0) { - if (not pages.empty()) { pages[ck_g.first_page + num_pages] = page_g; } + if (not pages.empty()) { + if (is_use_delta) { + page_g.kernel_mask = ENC_MASK_DELTA_BINARY; + } else if (ck_g.use_dictionary || physical_type == BOOLEAN) { + page_g.kernel_mask = ENC_MASK_DICTIONARY; + } else { + page_g.kernel_mask = ENC_MASK_PLAIN; + } + pages[ck_g.first_page + num_pages] = page_g; + } if (not page_sizes.empty()) { page_sizes[ck_g.first_page + num_pages] = page_g.max_data_size; } @@ -791,9 +892,14 @@ inline __device__ void PackLiterals( * @param[in] flush nonzero if last batch in block * @param[in] t thread id (0..127) */ +template static __device__ void RleEncode( - page_enc_state_s* s, uint32_t numvals, uint32_t nbits, uint32_t flush, uint32_t t) + state_buf* s, uint32_t numvals, uint32_t nbits, uint32_t flush, uint32_t t) { + using cudf::detail::warp_size; + auto const lane_id = t % warp_size; + auto const warp_id = t / warp_size; + uint32_t rle_pos = s->rle_pos; uint32_t rle_run = s->rle_run; @@ -801,20 +907,20 @@ static __device__ void RleEncode( uint32_t pos = rle_pos + t; if (rle_run > 0 && !(rle_run & 1)) { // Currently in a long repeat run - uint32_t mask = ballot(pos < numvals && s->vals[pos & (rle_buffer_size - 1)] == s->run_val); + uint32_t mask = ballot(pos < numvals && s->vals[rolling_idx(pos)] == s->run_val); uint32_t 
rle_rpt_count, max_rpt_count; - if (!(t & 0x1f)) { s->rpt_map[t >> 5] = mask; } + if (lane_id == 0) { s->rpt_map[warp_id] = mask; } __syncthreads(); - if (t < 32) { + if (t < warp_size) { uint32_t c32 = ballot(t >= 4 || s->rpt_map[t] != 0xffff'ffffu); - if (!t) { + if (t == 0) { uint32_t last_idx = __ffs(c32) - 1; s->rle_rpt_count = - last_idx * 32 + ((last_idx < 4) ? __ffs(~s->rpt_map[last_idx]) - 1 : 0); + last_idx * warp_size + ((last_idx < 4) ? __ffs(~s->rpt_map[last_idx]) - 1 : 0); } } __syncthreads(); - max_rpt_count = min(numvals - rle_pos, 128); + max_rpt_count = min(numvals - rle_pos, encode_block_size); rle_rpt_count = s->rle_rpt_count; rle_run += rle_rpt_count << 1; rle_pos += rle_rpt_count; @@ -831,17 +937,17 @@ static __device__ void RleEncode( } } else { // New run or in a literal run - uint32_t v0 = s->vals[pos & (rle_buffer_size - 1)]; - uint32_t v1 = s->vals[(pos + 1) & (rle_buffer_size - 1)]; + uint32_t v0 = s->vals[rolling_idx(pos)]; + uint32_t v1 = s->vals[rolling_idx(pos + 1)]; uint32_t mask = ballot(pos + 1 < numvals && v0 == v1); - uint32_t maxvals = min(numvals - rle_pos, 128); + uint32_t maxvals = min(numvals - rle_pos, encode_block_size); uint32_t rle_lit_count, rle_rpt_count; - if (!(t & 0x1f)) { s->rpt_map[t >> 5] = mask; } + if (lane_id == 0) { s->rpt_map[warp_id] = mask; } __syncthreads(); - if (t < 32) { + if (t < warp_size) { // Repeat run can only start on a multiple of 8 values - uint32_t idx8 = (t * 8) >> 5; - uint32_t pos8 = (t * 8) & 0x1f; + uint32_t idx8 = (t * 8) / warp_size; + uint32_t pos8 = (t * 8) % warp_size; uint32_t m0 = (idx8 < 4) ? s->rpt_map[idx8] : 0; uint32_t m1 = (idx8 < 3) ? s->rpt_map[idx8 + 1] : 0; uint32_t needed_mask = kRleRunMask[nbits - 1]; @@ -850,8 +956,8 @@ static __device__ void RleEncode( uint32_t rle_run_start = (mask != 0) ? min((__ffs(mask) - 1) * 8, maxvals) : maxvals; uint32_t rpt_len = 0; if (rle_run_start < maxvals) { - uint32_t idx_cur = rle_run_start >> 5; - uint32_t idx_ofs = rle_run_start & 0x1f; + uint32_t idx_cur = rle_run_start / warp_size; + uint32_t idx_ofs = rle_run_start % warp_size; while (idx_cur < 4) { m0 = (idx_cur < 4) ? s->rpt_map[idx_cur] : 0; m1 = (idx_cur < 3) ? s->rpt_map[idx_cur + 1] : 0; @@ -860,7 +966,7 @@ static __device__ void RleEncode( rpt_len += __ffs(mask) - 1; break; } - rpt_len += 32; + rpt_len += warp_size; idx_cur++; } } @@ -931,17 +1037,15 @@ static __device__ void RleEncode( * @param[in] flush nonzero if last batch in block * @param[in] t thread id (0..127) */ -static __device__ void PlainBoolEncode(page_enc_state_s* s, - uint32_t numvals, - uint32_t flush, - uint32_t t) +template +static __device__ void PlainBoolEncode(state_buf* s, uint32_t numvals, uint32_t flush, uint32_t t) { uint32_t rle_pos = s->rle_pos; uint8_t* dst = s->rle_out; while (rle_pos < numvals) { uint32_t pos = rle_pos + t; - uint32_t v = (pos < numvals) ? s->vals[pos & (rle_buffer_size - 1)] : 0; + uint32_t v = (pos < numvals) ? s->vals[rolling_idx(pos)] : 0; uint32_t n = min(numvals - rle_pos, 128); uint32_t nbytes = (n + ((flush) ? 7 : 0)) >> 3; if (!nbytes) { break; } @@ -995,28 +1099,22 @@ __device__ auto julian_days_with_time(int64_t v) return std::make_pair(dur_time_of_day_nanos, julian_days); } +// this has been split out into its own kernel because of the amount of shared memory required +// for the state buffer. encode kernels that don't use the RLE buffer can get started while +// the level data is encoded. +// FIXME: what should the args to launch_bounds be now? 
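+// (the RLE state buffer alone holds rle_buffer_size (256) uint32_t values, i.e. 1KiB of
+// shared memory per block, which the data-only kernels no longer have to carry)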
// blockDim(128, 1, 1) -template -__global__ void __launch_bounds__(128, 8) - gpuEncodePages(device_span pages, - device_span> comp_in, - device_span> comp_out, - device_span comp_results, - bool write_v2_headers) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodePageLevels(device_span pages, bool write_v2_headers) { - __shared__ __align__(8) page_enc_state_s state_g; - using block_reduce = cub::BlockReduce; - using block_scan = cub::BlockScan; - __shared__ union { - typename block_reduce::TempStorage reduce_storage; - typename block_scan::TempStorage scan_storage; - } temp_storage; + __shared__ __align__(8) page_enc_state_s state_g; - page_enc_state_s* const s = &state_g; - auto const t = threadIdx.x; + auto* const s = &state_g; + uint32_t const t = threadIdx.x; if (t == 0) { - state_g = page_enc_state_s{}; + state_g = page_enc_state_s{}; s->page = pages[blockIdx.x]; s->ck = *s->page.chunk; s->col = *s->ck.col_desc; @@ -1029,6 +1127,8 @@ __global__ void __launch_bounds__(128, 8) } __syncthreads(); + if ((s->page.kernel_mask & kernel_mask) == 0) { return; } + auto const is_v2 = s->page.page_type == PageType::DATA_PAGE_V2; // Encode Repetition and Definition levels @@ -1081,7 +1181,7 @@ __global__ void __launch_bounds__(128, 8) } while (is_col_struct); return def; }(); - s->vals[(rle_numvals + t) & (rle_buffer_size - 1)] = def_lvl; + s->vals[rolling_idx(rle_numvals + t)] = def_lvl; __syncthreads(); rle_numvals += nrows; RleEncode(s, rle_numvals, def_lvl_bits, (rle_numvals == s->page.num_rows), t); @@ -1091,13 +1191,12 @@ __global__ void __launch_bounds__(128, 8) uint8_t* const cur = s->cur; uint8_t* const rle_out = s->rle_out; uint32_t const rle_bytes = static_cast(rle_out - cur) - (is_v2 ? 0 : 4); - if (is_v2 && t == 0) { + if (not is_v2 && t < 4) { cur[t] = rle_bytes >> (t * 8); } + __syncwarp(); + if (t == 0) { + s->cur = rle_out; s->page.def_lvl_bytes = rle_bytes; - } else if (not is_v2 && t < 4) { - cur[t] = rle_bytes >> (t * 8); } - __syncwarp(); - if (t == 0) { s->cur = rle_out; } } } } else if (s->page.page_type != PageType::DICTIONARY_PAGE && @@ -1124,7 +1223,7 @@ __global__ void __launch_bounds__(128, 8) uint32_t idx = page_first_val_idx + rle_numvals + t; uint32_t lvl_val = (rle_numvals + t < s->page.num_values && idx < col_last_val_idx) ? lvl_val_data[idx] : 0; - s->vals[(rle_numvals + t) & (rle_buffer_size - 1)] = lvl_val; + s->vals[rolling_idx(rle_numvals + t)] = lvl_val; __syncthreads(); rle_numvals += nvals; RleEncode(s, rle_numvals, nbits, (rle_numvals == s->page.num_values), t); @@ -1134,19 +1233,109 @@ __global__ void __launch_bounds__(128, 8) uint8_t* const cur = s->cur; uint8_t* const rle_out = s->rle_out; uint32_t const rle_bytes = static_cast(rle_out - cur) - (is_v2 ? 
0 : 4); - if (is_v2 && t == 0) { + if (not is_v2 && t < 4) { cur[t] = rle_bytes >> (t * 8); } + __syncwarp(); + if (t == 0) { + s->cur = rle_out; lvl_bytes = rle_bytes; - } else if (not is_v2 && t < 4) { - cur[t] = rle_bytes >> (t * 8); } - __syncwarp(); - if (t == 0) { s->cur = rle_out; } } }; encode_levels(s->col.rep_values, s->col.num_rep_level_bits(), s->page.rep_lvl_bytes); __syncthreads(); encode_levels(s->col.def_values, s->col.num_def_level_bits(), s->page.def_lvl_bytes); } + + if (t == 0) { pages[blockIdx.x] = s->page; } +} + +template +__device__ void finish_page_encode(state_buf* s, + uint32_t valid_count, + uint8_t const* end_ptr, + device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results, + bool write_v2_headers) +{ + auto const t = threadIdx.x; + + // V2 does not compress rep and def level data + size_t const skip_comp_size = + write_v2_headers ? s->page.def_lvl_bytes + s->page.rep_lvl_bytes : 0; + + if (t == 0) { + // only need num_nulls for v2 data page headers + if (write_v2_headers) { s->page.num_nulls = s->page.num_values - valid_count; } + uint8_t const* const base = s->page.page_data + s->page.max_hdr_size; + auto const actual_data_size = static_cast(end_ptr - base); + if (actual_data_size > s->page.max_data_size) { + printf("data corruption %d %d\n", actual_data_size, s->page.max_data_size); + CUDF_UNREACHABLE("detected possible page data corruption"); + } + s->page.max_data_size = actual_data_size; + if (not comp_in.empty()) { + comp_in[blockIdx.x] = {base + skip_comp_size, actual_data_size - skip_comp_size}; + comp_out[blockIdx.x] = {s->page.compressed_data + s->page.max_hdr_size + skip_comp_size, + 0}; // size is unused + } + pages[blockIdx.x] = s->page; + if (not comp_results.empty()) { + comp_results[blockIdx.x] = {0, compression_status::FAILURE}; + pages[blockIdx.x].comp_res = &comp_results[blockIdx.x]; + } + } + + // copy uncompressed bytes over + if (skip_comp_size != 0 && not comp_in.empty()) { + uint8_t* src = s->page.page_data + s->page.max_hdr_size; + uint8_t* dst = s->page.compressed_data + s->page.max_hdr_size; + for (int i = t; i < skip_comp_size; i += block_size) { + dst[i] = src[i]; + } + } +} + +// FIXME: what should the args to launch_bounds be now? 
+// blockDim(128, 1, 1) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodePages(device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results, + bool write_v2_headers) +{ + __shared__ __align__(8) page_enc_state_s<0> state_g; + using block_reduce = cub::BlockReduce; + using block_scan = cub::BlockScan; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename block_scan::TempStorage scan_storage; + } temp_storage; + + auto* const s = &state_g; + uint32_t t = threadIdx.x; + + if (t == 0) { + state_g = page_enc_state_s<0>{}; + s->page = pages[blockIdx.x]; + s->ck = *s->page.chunk; + s->col = *s->ck.col_desc; + s->rle_len_pos = nullptr; + // get s->cur back to where it was at the end of encoding the rep and def level data + s->cur = + s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + if (s->page.page_type == PageType::DATA_PAGE) { + if (s->col.num_def_level_bits() != 0) { s->cur += 4; } + if (s->col.num_rep_level_bits() != 0) { s->cur += 4; } + } + } + __syncthreads(); + + if ((s->page.kernel_mask & ENC_MASK_PLAIN) == 0) { return; } + // Encode data values __syncthreads(); auto const physical_type = s->col.physical_type; @@ -1158,10 +1347,6 @@ __global__ void __launch_bounds__(128, 8) return dtype_len_out; }(); - auto const dict_bits = (physical_type == BOOLEAN) ? 1 - : (s->ck.use_dictionary and s->page.page_type != PageType::DICTIONARY_PAGE) - ? s->ck.dict_rle_bits - : -1; if (t == 0) { uint8_t* dst = s->cur; s->rle_run = 0; @@ -1170,219 +1355,314 @@ __global__ void __launch_bounds__(128, 8) s->rle_out = dst; s->page.encoding = determine_encoding(s->page.page_type, physical_type, s->ck.use_dictionary, write_v2_headers); - if (dict_bits >= 0 && physical_type != BOOLEAN) { - dst[0] = dict_bits; - s->rle_out = dst + 1; - } else if (is_v2 && physical_type == BOOLEAN) { - // save space for RLE length. we don't know the total length yet. - s->rle_out = dst + RLE_LENGTH_FIELD_LEN; - s->rle_len_pos = dst; - } s->page_start_val = row_to_value_idx(s->page.start_row, s->col); s->chunk_start_val = row_to_value_idx(s->ck.start_row, s->col); } __syncthreads(); + uint32_t num_valid = 0; for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { - uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, 128); + uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, block_size); uint32_t len, pos; auto [is_valid, val_idx] = [&]() { uint32_t val_idx; uint32_t is_valid; - size_type val_idx_in_block = cur_val_idx + t; + size_type const val_idx_in_block = cur_val_idx + t; if (s->page.page_type == PageType::DICTIONARY_PAGE) { val_idx = val_idx_in_block; is_valid = (val_idx < s->page.num_leaf_values); if (is_valid) { val_idx = s->ck.dict_data[val_idx]; } } else { - size_type val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && val_idx_in_block < s->page.num_leaf_values) ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) : 0; - val_idx = - (s->ck.use_dictionary) ? 
val_idx_in_leaf_col - s->chunk_start_val : val_idx_in_leaf_col; + val_idx = val_idx_in_leaf_col; } return std::make_tuple(is_valid, val_idx); }(); - if (is_valid) num_valid++; - + if (is_valid) { num_valid++; } cur_val_idx += nvals; - if (dict_bits >= 0) { - // Dictionary encoding - if (dict_bits > 0) { - uint32_t rle_numvals; - uint32_t rle_numvals_in_block; - block_scan(temp_storage.scan_storage).ExclusiveSum(is_valid, pos, rle_numvals_in_block); - rle_numvals = s->rle_numvals; - if (is_valid) { - uint32_t v; - if (physical_type == BOOLEAN) { - v = s->col.leaf_column->element(val_idx); - } else { - v = s->ck.dict_index[val_idx]; - } - s->vals[(rle_numvals + pos) & (rle_buffer_size - 1)] = v; - } - rle_numvals += rle_numvals_in_block; - __syncthreads(); - if (!is_v2 && physical_type == BOOLEAN) { - PlainBoolEncode(s, rle_numvals, (cur_val_idx == s->page.num_leaf_values), t); - } else { - RleEncode(s, rle_numvals, dict_bits, (cur_val_idx == s->page.num_leaf_values), t); + + // Non-dictionary encoding + uint8_t* dst = s->cur; + + if (is_valid) { + len = dtype_len_out; + if (physical_type == BYTE_ARRAY) { + if (type_id == type_id::STRING) { + len += s->col.leaf_column->element(val_idx).size_bytes(); + } else if (s->col.output_as_byte_array && type_id == type_id::LIST) { + len += + get_element(*s->col.leaf_column, val_idx).size_bytes(); } - __syncthreads(); } - if (t == 0) { s->cur = s->rle_out; } - __syncthreads(); } else { - // Non-dictionary encoding - uint8_t* dst = s->cur; - - if (is_valid) { - len = dtype_len_out; - if (physical_type == BYTE_ARRAY) { - if (type_id == type_id::STRING) { - len += s->col.leaf_column->element(val_idx).size_bytes(); - } else if (s->col.output_as_byte_array && type_id == type_id::LIST) { - len += - get_element(*s->col.leaf_column, val_idx).size_bytes(); + len = 0; + } + uint32_t total_len = 0; + block_scan(temp_storage.scan_storage).ExclusiveSum(len, pos, total_len); + __syncthreads(); + if (t == 0) { s->cur = dst + total_len; } + if (is_valid) { + switch (physical_type) { + case INT32: [[fallthrough]]; + case FLOAT: { + auto const v = [dtype_len = dtype_len_in, + idx = val_idx, + col = s->col.leaf_column, + scale = s->col.ts_scale == 0 ? 
1 : s->col.ts_scale]() -> int32_t { + switch (dtype_len) { + case 8: return col->element(idx) * scale; + case 4: return col->element(idx) * scale; + case 2: return col->element(idx) * scale; + default: return col->element(idx) * scale; + } + }(); + + dst[pos + 0] = v; + dst[pos + 1] = v >> 8; + dst[pos + 2] = v >> 16; + dst[pos + 3] = v >> 24; + } break; + case INT64: { + int64_t v = s->col.leaf_column->element(val_idx); + int32_t ts_scale = s->col.ts_scale; + if (ts_scale != 0) { + if (ts_scale < 0) { + v /= -ts_scale; + } else { + v *= ts_scale; + } + } + dst[pos + 0] = v; + dst[pos + 1] = v >> 8; + dst[pos + 2] = v >> 16; + dst[pos + 3] = v >> 24; + dst[pos + 4] = v >> 32; + dst[pos + 5] = v >> 40; + dst[pos + 6] = v >> 48; + dst[pos + 7] = v >> 56; + } break; + case INT96: { + int64_t v = s->col.leaf_column->element(val_idx); + int32_t ts_scale = s->col.ts_scale; + if (ts_scale != 0) { + if (ts_scale < 0) { + v /= -ts_scale; + } else { + v *= ts_scale; + } } - } - } else { - len = 0; - } - uint32_t total_len = 0; - block_scan(temp_storage.scan_storage).ExclusiveSum(len, pos, total_len); - __syncthreads(); - if (t == 0) { s->cur = dst + total_len; } - if (is_valid) { - switch (physical_type) { - case INT32: [[fallthrough]]; - case FLOAT: { - auto const v = [dtype_len = dtype_len_in, - idx = val_idx, - col = s->col.leaf_column, - scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale]() -> int32_t { - switch (dtype_len) { - case 8: return col->element(idx) * scale; - case 4: return col->element(idx) * scale; - case 2: return col->element(idx) * scale; - default: return col->element(idx) * scale; - } - }(); - dst[pos + 0] = v; - dst[pos + 1] = v >> 8; - dst[pos + 2] = v >> 16; - dst[pos + 3] = v >> 24; - } break; - case INT64: { - int64_t v = s->col.leaf_column->element(val_idx); - int32_t ts_scale = s->col.ts_scale; - if (ts_scale != 0) { - if (ts_scale < 0) { - v /= -ts_scale; - } else { - v *= ts_scale; - } + auto const [last_day_nanos, julian_days] = [&] { + using namespace cuda::std::chrono; + switch (s->col.leaf_column->type().id()) { + case type_id::TIMESTAMP_SECONDS: + case type_id::TIMESTAMP_MILLISECONDS: { + return julian_days_with_time(v); + } break; + case type_id::TIMESTAMP_MICROSECONDS: + case type_id::TIMESTAMP_NANOSECONDS: { + return julian_days_with_time(v); + } break; } - dst[pos + 0] = v; - dst[pos + 1] = v >> 8; - dst[pos + 2] = v >> 16; - dst[pos + 3] = v >> 24; - dst[pos + 4] = v >> 32; - dst[pos + 5] = v >> 40; - dst[pos + 6] = v >> 48; - dst[pos + 7] = v >> 56; - } break; - case INT96: { - int64_t v = s->col.leaf_column->element(val_idx); - int32_t ts_scale = s->col.ts_scale; - if (ts_scale != 0) { - if (ts_scale < 0) { - v /= -ts_scale; - } else { - v *= ts_scale; - } + return julian_days_with_time(0); + }(); + + // the 12 bytes of fixed length data. 
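+            // (8 bytes of little-endian nanoseconds within the day, followed by 4 bytes of
+            // little-endian Julian day number)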
+ v = last_day_nanos.count(); + dst[pos + 0] = v; + dst[pos + 1] = v >> 8; + dst[pos + 2] = v >> 16; + dst[pos + 3] = v >> 24; + dst[pos + 4] = v >> 32; + dst[pos + 5] = v >> 40; + dst[pos + 6] = v >> 48; + dst[pos + 7] = v >> 56; + uint32_t w = julian_days.count(); + dst[pos + 8] = w; + dst[pos + 9] = w >> 8; + dst[pos + 10] = w >> 16; + dst[pos + 11] = w >> 24; + } break; + + case DOUBLE: { + auto v = s->col.leaf_column->element(val_idx); + memcpy(dst + pos, &v, 8); + } break; + case BYTE_ARRAY: { + auto const bytes = [](cudf::type_id const type_id, + column_device_view const* leaf_column, + uint32_t const val_idx) -> void const* { + switch (type_id) { + case type_id::STRING: + return reinterpret_cast( + leaf_column->element(val_idx).data()); + case type_id::LIST: + return reinterpret_cast( + get_element(*(leaf_column), val_idx).data()); + default: CUDF_UNREACHABLE("invalid type id for byte array writing!"); } + }(type_id, s->col.leaf_column, val_idx); + uint32_t v = len - 4; // string length + dst[pos + 0] = v; + dst[pos + 1] = v >> 8; + dst[pos + 2] = v >> 16; + dst[pos + 3] = v >> 24; + if (v != 0) memcpy(dst + pos + 4, bytes, v); + } break; + case FIXED_LEN_BYTE_ARRAY: { + if (type_id == type_id::DECIMAL128) { + // When using FIXED_LEN_BYTE_ARRAY for decimals, the rep is encoded in big-endian + auto const v = s->col.leaf_column->element(val_idx).value(); + auto const v_char_ptr = reinterpret_cast(&v); + thrust::copy(thrust::seq, + thrust::make_reverse_iterator(v_char_ptr + sizeof(v)), + thrust::make_reverse_iterator(v_char_ptr), + dst + pos); + } + } break; + } + } + __syncthreads(); + } - auto const [last_day_nanos, julian_days] = [&] { - using namespace cuda::std::chrono; - switch (s->col.leaf_column->type().id()) { - case type_id::TIMESTAMP_SECONDS: - case type_id::TIMESTAMP_MILLISECONDS: { - return julian_days_with_time(v); - } break; - case type_id::TIMESTAMP_MICROSECONDS: - case type_id::TIMESTAMP_NANOSECONDS: { - return julian_days_with_time(v); - } break; - } - return julian_days_with_time(0); - }(); - - // the 12 bytes of fixed length data. 
- v = last_day_nanos.count(); - dst[pos + 0] = v; - dst[pos + 1] = v >> 8; - dst[pos + 2] = v >> 16; - dst[pos + 3] = v >> 24; - dst[pos + 4] = v >> 32; - dst[pos + 5] = v >> 40; - dst[pos + 6] = v >> 48; - dst[pos + 7] = v >> 56; - uint32_t w = julian_days.count(); - dst[pos + 8] = w; - dst[pos + 9] = w >> 8; - dst[pos + 10] = w >> 16; - dst[pos + 11] = w >> 24; - } break; + uint32_t const valid_count = block_reduce(temp_storage.reduce_storage).Sum(num_valid); - case DOUBLE: { - auto v = s->col.leaf_column->element(val_idx); - memcpy(dst + pos, &v, 8); - } break; - case BYTE_ARRAY: { - auto const bytes = [](cudf::type_id const type_id, - column_device_view const* leaf_column, - uint32_t const val_idx) -> void const* { - switch (type_id) { - case type_id::STRING: - return reinterpret_cast( - leaf_column->element(val_idx).data()); - case type_id::LIST: - return reinterpret_cast( - get_element(*(leaf_column), val_idx).data()); - default: CUDF_UNREACHABLE("invalid type id for byte array writing!"); - } - }(type_id, s->col.leaf_column, val_idx); - uint32_t v = len - 4; // string length - dst[pos + 0] = v; - dst[pos + 1] = v >> 8; - dst[pos + 2] = v >> 16; - dst[pos + 3] = v >> 24; - if (v != 0) memcpy(dst + pos + 4, bytes, v); - } break; - case FIXED_LEN_BYTE_ARRAY: { - if (type_id == type_id::DECIMAL128) { - // When using FIXED_LEN_BYTE_ARRAY for decimals, the rep is encoded in big-endian - auto const v = s->col.leaf_column->element(val_idx).value(); - auto const v_char_ptr = reinterpret_cast(&v); - thrust::copy(thrust::seq, - thrust::make_reverse_iterator(v_char_ptr + sizeof(v)), - thrust::make_reverse_iterator(v_char_ptr), - dst + pos); - } - } break; + finish_page_encode( + s, valid_count, s->cur, pages, comp_in, comp_out, comp_results, write_v2_headers); +} + +// FIXME: what should the args to launch_bounds be now? +// blockDim(128, 1, 1) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodeDictPages(device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results, + bool write_v2_headers) +{ + __shared__ __align__(8) page_enc_state_s state_g; + using block_reduce = cub::BlockReduce; + using block_scan = cub::BlockScan; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename block_scan::TempStorage scan_storage; + } temp_storage; + + auto* const s = &state_g; + uint32_t t = threadIdx.x; + + if (t == 0) { + state_g = page_enc_state_s{}; + s->page = pages[blockIdx.x]; + s->ck = *s->page.chunk; + s->col = *s->ck.col_desc; + s->rle_len_pos = nullptr; + // get s->cur back to where it was at the end of encoding the rep and def level data + s->cur = + s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + if (s->page.page_type == PageType::DATA_PAGE) { + if (s->col.num_def_level_bits() != 0) { s->cur += 4; } + if (s->col.num_rep_level_bits() != 0) { s->cur += 4; } + } + } + __syncthreads(); + + if ((s->page.kernel_mask & ENC_MASK_DICTIONARY) == 0) { return; } + + // Encode data values + __syncthreads(); + auto const physical_type = s->col.physical_type; + auto const type_id = s->col.leaf_column->type().id(); + auto const dtype_len_out = physical_type_len(physical_type, type_id); + auto const dtype_len_in = [&]() -> uint32_t { + if (physical_type == INT32) { return int32_logical_len(type_id); } + if (physical_type == INT96) { return sizeof(int64_t); } + return dtype_len_out; + }(); + + // TODO assert dict_bits >= 0 + auto const dict_bits = (physical_type == BOOLEAN) ? 
1 + : (s->ck.use_dictionary and s->page.page_type != PageType::DICTIONARY_PAGE) + ? s->ck.dict_rle_bits + : -1; + if (t == 0) { + uint8_t* dst = s->cur; + s->rle_run = 0; + s->rle_pos = 0; + s->rle_numvals = 0; + s->rle_out = dst; + s->page.encoding = + determine_encoding(s->page.page_type, physical_type, s->ck.use_dictionary, write_v2_headers); + if (dict_bits >= 0 && physical_type != BOOLEAN) { + dst[0] = dict_bits; + s->rle_out = dst + 1; + } else if (write_v2_headers && physical_type == BOOLEAN) { + // save space for RLE length. we don't know the total length yet. + s->rle_out = dst + RLE_LENGTH_FIELD_LEN; + s->rle_len_pos = dst; + } + s->page_start_val = row_to_value_idx(s->page.start_row, s->col); + s->chunk_start_val = row_to_value_idx(s->ck.start_row, s->col); + } + __syncthreads(); + + uint32_t num_valid = 0; + for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { + uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, block_size); + + auto [is_valid, val_idx] = [&]() { + size_type const val_idx_in_block = cur_val_idx + t; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + + uint32_t const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && + val_idx_in_block < s->page.num_leaf_values) + ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) + : 0; + // need to test for use_dictionary because it might be boolean + uint32_t const val_idx = + (s->ck.use_dictionary) ? val_idx_in_leaf_col - s->chunk_start_val : val_idx_in_leaf_col; + return std::make_tuple(is_valid, val_idx); + }(); + + if (is_valid) { num_valid++; } + cur_val_idx += nvals; + + // Dictionary encoding + if (dict_bits > 0) { + uint32_t rle_numvals; + uint32_t rle_numvals_in_block; + uint32_t pos; + block_scan(temp_storage.scan_storage).ExclusiveSum(is_valid, pos, rle_numvals_in_block); + rle_numvals = s->rle_numvals; + if (is_valid) { + uint32_t v; + if (physical_type == BOOLEAN) { + v = s->col.leaf_column->element(val_idx); + } else { + v = s->ck.dict_index[val_idx]; } + s->vals[rolling_idx(rle_numvals + pos)] = v; + } + rle_numvals += rle_numvals_in_block; + __syncthreads(); + if ((!write_v2_headers) && (physical_type == BOOLEAN)) { + PlainBoolEncode(s, rle_numvals, (cur_val_idx == s->page.num_leaf_values), t); + } else { + RleEncode(s, rle_numvals, dict_bits, (cur_val_idx == s->page.num_leaf_values), t); } __syncthreads(); } + if (t == 0) { s->cur = s->rle_out; } + __syncthreads(); } uint32_t const valid_count = block_reduce(temp_storage.reduce_storage).Sum(num_valid); @@ -1390,42 +1670,139 @@ __global__ void __launch_bounds__(128, 8) // save RLE length if necessary if (s->rle_len_pos != nullptr && t < 32) { // size doesn't include the 4 bytes for the length - auto const rle_size = static_cast(s->cur - s->rle_len_pos) - RLE_LENGTH_FIELD_LEN; - if (t < RLE_LENGTH_FIELD_LEN) { s->rle_len_pos[t] = rle_size >> (t * 8); } + auto const rle_size = static_cast(s->cur - s->rle_len_pos) - 4; + if (t < 4) { s->rle_len_pos[t] = rle_size >> (t * 8); } __syncwarp(); } - // V2 does not compress rep and def level data - size_t const skip_comp_size = s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + finish_page_encode( + s, valid_count, s->cur, pages, comp_in, comp_out, comp_results, write_v2_headers); +} + +// FIXME: what should the args to launch_bounds be now? 
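+// (note: the delta kernel's shared value buffer declared below is (buffer_size + block_size)
+// = 384 uint64_t entries, i.e. 3KiB per block, on top of the cub scratch in temp_storage)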
+// blockDim(128, 1, 1) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodeDeltaBinaryPages(device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results) +{ + // block of shared memory for value storage and bit packing + // TODO add constant that's the sum of buffer_size and block_size + __shared__ uint64_t delta_shared[delta::buffer_size + delta::block_size]; + __shared__ __align__(8) page_enc_state_s<0> state_g; + using block_reduce = cub::BlockReduce; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename delta::index_scan::TempStorage delta_index_tmp; + typename delta::block_reduce::TempStorage delta_reduce_tmp; + typename delta::warp_reduce::TempStorage delta_warp_red_tmp[delta::num_mini_blocks]; + } temp_storage; + + auto* const s = &state_g; + uint32_t t = threadIdx.x; if (t == 0) { - s->page.num_nulls = s->page.num_values - valid_count; - uint8_t* const base = s->page.page_data + s->page.max_hdr_size; - auto const actual_data_size = static_cast(s->cur - base); - if (actual_data_size > s->page.max_data_size) { - CUDF_UNREACHABLE("detected possible page data corruption"); - } - s->page.max_data_size = actual_data_size; - if (not comp_in.empty()) { - comp_in[blockIdx.x] = {base + skip_comp_size, actual_data_size - skip_comp_size}; - comp_out[blockIdx.x] = {s->page.compressed_data + s->page.max_hdr_size + skip_comp_size, - 0}; // size is unused - } - pages[blockIdx.x] = s->page; - if (not comp_results.empty()) { - comp_results[blockIdx.x] = {0, compression_status::FAILURE}; - pages[blockIdx.x].comp_res = &comp_results[blockIdx.x]; - } + state_g = page_enc_state_s<0>{}; + s->page = pages[blockIdx.x]; + s->ck = *s->page.chunk; + s->col = *s->ck.col_desc; + s->rle_len_pos = nullptr; + // get s->cur back to where it was at the end of encoding the rep and def level data + s->cur = + s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; } + __syncthreads(); - // copy over uncompressed data - if (skip_comp_size != 0 && not comp_in.empty()) { - uint8_t const* const src = s->page.page_data + s->page.max_hdr_size; - uint8_t* const dst = s->page.compressed_data + s->page.max_hdr_size; - for (int i = t; i < skip_comp_size; i += block_size) { - dst[i] = src[i]; + if ((s->page.kernel_mask & ENC_MASK_DELTA_BINARY) == 0) { return; } + + // Encode data values + __syncthreads(); + auto const physical_type = s->col.physical_type; + auto const type_id = s->col.leaf_column->type().id(); + auto const dtype_len_out = physical_type_len(physical_type, type_id); + auto const dtype_len_in = [&]() -> uint32_t { + if (physical_type == INT32) { return int32_logical_len(type_id); } + if (physical_type == INT96) { return sizeof(int64_t); } + return dtype_len_out; + }(); + + if (t == 0) { + uint8_t* dst = s->cur; + s->rle_run = 0; + s->rle_pos = 0; + s->rle_numvals = 0; + s->rle_out = dst; + s->page.encoding = Encoding::DELTA_BINARY_PACKED; + s->page_start_val = row_to_value_idx(s->page.start_row, s->col); + s->chunk_start_val = row_to_value_idx(s->ck.start_row, s->col); + } + __syncthreads(); + + // need to know the number of valid values for the null values calculation and to size + // the delta binary encoder. 
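+  // (nulls are carried entirely by the definition levels, so the value count written to
+  // the DELTA_BINARY_PACKED header covers only the non-null values)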
+ uint32_t valid_count = 0; + if (not s->col.leaf_column->nullable()) { + valid_count = s->page.num_leaf_values; + } else { + uint32_t num_valid = 0; + for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { + uint32_t const nvals = min(s->page.num_leaf_values - cur_val_idx, block_size); + size_type const val_idx_in_block = cur_val_idx + t; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + + if (val_idx_in_leaf_col < s->col.leaf_column->size() && + val_idx_in_block < s->page.num_leaf_values && + s->col.leaf_column->is_valid(val_idx_in_leaf_col)) { + num_valid++; + } + cur_val_idx += nvals; + } + valid_count = block_reduce(temp_storage.reduce_storage).Sum(num_valid); + } + + auto calc_idx_and_validity = [&](uint32_t cur_val_idx) { + size_type const val_idx_in_block = cur_val_idx + t; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + + uint32_t const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && + val_idx_in_block < s->page.num_leaf_values) + ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) + : 0; + + return std::make_tuple(is_valid, val_idx_in_leaf_col); + }; + + uint8_t const* delta_ptr = nullptr; // this will be the end of delta block pointer + + if (physical_type == INT32) { + // FIXME need to handle all the time scaling stuff here too + if (dtype_len_in == 4) { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else if (dtype_len_in == 2) { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else if (dtype_len_in == 8) { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); } + } else { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); } + + finish_page_encode( + s, valid_count, delta_ptr, pages, comp_in, comp_out, comp_results, true); } constexpr int decide_compression_warps_in_block = 4; @@ -1460,7 +1837,8 @@ __global__ void __launch_bounds__(decide_compression_block_size) for (auto page_id = lane_id; page_id < num_pages; page_id += cudf::detail::warp_size) { auto const& curr_page = ck_g[warp_id].pages[page_id]; auto const page_data_size = curr_page.max_data_size; - auto const lvl_bytes = curr_page.def_lvl_bytes + curr_page.rep_lvl_bytes; + auto const is_v2 = curr_page.page_type == PageType::DATA_PAGE_V2; + auto const lvl_bytes = is_v2 ? curr_page.def_lvl_bytes + curr_page.rep_lvl_bytes : 0; uncompressed_data_size += page_data_size; if (auto comp_res = curr_page.comp_res; comp_res != nullptr) { compressed_data_size += comp_res->bytes_written + lvl_bytes; @@ -1923,7 +2301,8 @@ __global__ void __launch_bounds__(128) } uncompressed_page_size = page_g.max_data_size; if (ck_g.is_compressed) { - auto const lvl_bytes = page_g.def_lvl_bytes + page_g.rep_lvl_bytes; + auto const is_v2 = page_g.page_type == PageType::DATA_PAGE_V2; + auto const lvl_bytes = is_v2 ? 
page_g.def_lvl_bytes + page_g.rep_lvl_bytes : 0; hdr_start = page_g.compressed_data; compressed_page_size = static_cast(comp_results[blockIdx.x].bytes_written) + lvl_bytes; @@ -2158,6 +2537,10 @@ constexpr __device__ void* align8(void* ptr) return static_cast(ptr) - algn; } +struct mask_tform { + __device__ uint32_t operator()(EncPage const& p) { return p.kernel_mask; } +}; + } // namespace // blockDim(1, 1, 1) @@ -2260,8 +2643,9 @@ void InitFragmentStatistics(device_span groups, rmm::cuda_stream_view stream) { int const num_fragments = fragments.size(); - int const dim = util::div_rounding_up_safe(num_fragments, 128 / cudf::detail::warp_size); - gpuInitFragmentStats<<>>(groups, fragments); + int const dim = + util::div_rounding_up_safe(num_fragments, encode_block_size / cudf::detail::warp_size); + gpuInitFragmentStats<<>>(groups, fragments); } void InitEncoderPages(device_2dspan chunks, @@ -2280,18 +2664,18 @@ void InitEncoderPages(device_2dspan chunks, { auto num_rowgroups = chunks.size().first; dim3 dim_grid(num_columns, num_rowgroups); // 1 threadblock per rowgroup - gpuInitPages<<>>(chunks, - pages, - page_sizes, - comp_page_sizes, - col_desc, - page_grstats, - chunk_grstats, - num_columns, - max_page_size_bytes, - max_page_size_rows, - page_align, - write_v2_headers); + gpuInitPages<<>>(chunks, + pages, + page_sizes, + comp_page_sizes, + col_desc, + page_grstats, + chunk_grstats, + num_columns, + max_page_size_bytes, + max_page_size_rows, + page_align, + write_v2_headers); } void EncodePages(device_span pages, @@ -2302,10 +2686,43 @@ void EncodePages(device_span pages, rmm::cuda_stream_view stream) { auto num_pages = pages.size(); + + // determine which kernels to invoke + auto mask_iter = thrust::make_transform_iterator(pages.begin(), mask_tform{}); + int kernel_mask = thrust::reduce( + rmm::exec_policy(stream), mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); + + // get the number of streams we need from the pool + int nkernels = std::bitset<32>(kernel_mask).count(); + auto streams = cudf::detail::fork_streams(stream, nkernels); + // A page is part of one column. This is launching 1 block per page. 1 block will exclusively // deal with one datatype. - gpuEncodePages<128><<>>( - pages, comp_in, comp_out, comp_results, write_v2_headers); + + int s_idx = 0; + if ((kernel_mask & ENC_MASK_PLAIN) != 0) { + auto const strm = streams[s_idx++]; + gpuEncodePageLevels + <<>>(pages, write_v2_headers); + gpuEncodePages<<>>( + pages, comp_in, comp_out, comp_results, write_v2_headers); + } + if ((kernel_mask & ENC_MASK_DELTA_BINARY) != 0) { + auto const strm = streams[s_idx++]; + gpuEncodePageLevels + <<>>(pages, write_v2_headers); + gpuEncodeDeltaBinaryPages + <<>>(pages, comp_in, comp_out, comp_results); + } + if ((kernel_mask & ENC_MASK_DICTIONARY) != 0) { + auto const strm = streams[s_idx++]; + gpuEncodePageLevels + <<>>(pages, write_v2_headers); + gpuEncodeDictPages<<>>( + pages, comp_in, comp_out, comp_results, write_v2_headers); + } + + cudf::detail::join_streams(streams, stream); } void DecideCompression(device_span chunks, rmm::cuda_stream_view stream) @@ -2323,7 +2740,7 @@ void EncodePageHeaders(device_span pages, { // TODO: single thread task. No need for 128 threads/block. 
Earlier it used to employ rest of the // threads to coop load structs - gpuEncodePageHeaders<<>>( + gpuEncodePageHeaders<<>>( pages, comp_results, page_stats, chunk_stats); } diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index e82b6abc13d..c2892ed6495 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -75,6 +75,12 @@ struct input_column_info { namespace gpu { +// TODO: The delta encodings use ULEB128 integers, but for now we're only +// using max 64 bits. Need to see what the performance impact is of using +// __int128_t rather than int64_t. +using uleb128_t = uint64_t; +using zigzag128_t = int64_t; + /** * @brief Enums for the flags in the page header */ @@ -390,6 +396,17 @@ constexpr uint32_t encoding_to_mask(Encoding encoding) return 1 << static_cast(encoding); } +/** + * @brief Enum of mask bits for the EncPage kernel_mask + * + * Used to control which encode kernels to run. + */ +enum encoder_kernel_mask_bits { + ENC_MASK_PLAIN = (1 << 0), // Run plain encoding kernel + ENC_MASK_DICTIONARY = (1 << 1), // Run dictionary encoding kernel + ENC_MASK_DELTA_BINARY = (1 << 2) // Run DELTA_BINARY_PACKED encoding kernel +}; + /** * @brief Struct describing an encoder column chunk */ @@ -452,6 +469,7 @@ struct EncPage { uint32_t rep_lvl_bytes; //!< Number of bytes of encoded repetition level data (V2 only) compression_result* comp_res; //!< Ptr to compression result uint32_t num_nulls; //!< Number of null values (V2 only) (down here for alignment) + uint32_t kernel_mask; //!< Mask used to control which encoding kernels to run }; /** diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index 64aca091686..907805d0abd 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -540,7 +540,9 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, Timestamps) auto filepath = temp_env->get_temp_filepath("Timestamps.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) + .write_v2_headers(true) + .dictionary_policy(cudf::io::dictionary_policy::NEVER); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -566,7 +568,9 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, TimestampsWithNulls) auto filepath = temp_env->get_temp_filepath("TimestampsWithNulls.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) + .write_v2_headers(true) + .dictionary_policy(cudf::io::dictionary_policy::NEVER); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -590,7 +594,9 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, TimestampOverflow) auto filepath = temp_env->get_temp_filepath("ParquetTimestampOverflow.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) + .write_v2_headers(true) + .dictionary_policy(cudf::io::dictionary_policy::NEVER); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -6672,7 +6678,7 @@ TEST_P(ParquetV2Test, CheckEncodings) // data should be PLAIN for v1, RLE for V2 auto col0_data = 
cudf::detail::make_counting_transform_iterator(0, [](auto i) -> bool { return i % 2 == 0; }); - // data should be PLAIN for both + // data should be PLAIN for v1, DELTA_BINARY_PACKED for v2 auto col1_data = random_values(num_rows); // data should be PLAIN_DICTIONARY for v1, PLAIN and RLE_DICTIONARY for v2 auto col2_data = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return 1; }); @@ -6707,10 +6713,10 @@ TEST_P(ParquetV2Test, CheckEncodings) // col0 should have RLE for rep/def and data EXPECT_TRUE(chunk0_enc.size() == 1); EXPECT_TRUE(contains(chunk0_enc, Encoding::RLE)); - // col1 should have RLE for rep/def and PLAIN for data + // col1 should have RLE for rep/def and DELTA_BINARY_PACKED for data EXPECT_TRUE(chunk1_enc.size() == 2); EXPECT_TRUE(contains(chunk1_enc, Encoding::RLE)); - EXPECT_TRUE(contains(chunk1_enc, Encoding::PLAIN)); + EXPECT_TRUE(contains(chunk1_enc, Encoding::DELTA_BINARY_PACKED)); // col2 should have RLE for rep/def, PLAIN for dict, and RLE_DICTIONARY for data EXPECT_TRUE(chunk2_enc.size() == 3); EXPECT_TRUE(contains(chunk2_enc, Encoding::RLE)); From 478f8acd7abc3074acaaa6b9ce3ef64aea56cee6 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 13 Sep 2023 10:31:51 -0700 Subject: [PATCH 02/37] clean up some consts and change kernel mask to enum class --- cpp/src/io/parquet/delta_enc.cuh | 11 +++---- cpp/src/io/parquet/page_enc.cu | 47 +++++++++++++++++++----------- cpp/src/io/parquet/parquet_gpu.hpp | 10 +++---- 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index 164849edd63..471774b2561 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -19,6 +19,7 @@ #include "parquet_gpu.hpp" #include +#include #include @@ -59,7 +60,7 @@ template inline __device__ void bitpack_mini_block( uint8_t* dst, T val, uint32_t count, uint8_t nbits, void* temp_space) { - // typing for atomicOr is annoying + // typing for atomicOr is annoying. uint64_t doesn't work, need unsigned long long. 
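+  // (device atomicOr has overloads for unsigned int and unsigned long long, but none for
+  // uint64_t itself when that type is an alias of unsigned long)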
using scratch_type = std::conditional_t, unsigned long long, uint32_t>; using cudf::detail::warp_size; @@ -69,7 +70,7 @@ inline __device__ void bitpack_mini_block( auto const lane_id = threadIdx.x % warp_size; auto const warp_id = threadIdx.x / warp_size; - scratch_type* scratch = reinterpret_cast(temp_space) + warp_id * warp_size; + auto const scratch = reinterpret_cast(temp_space) + warp_id * warp_size; // zero out scratch scratch[lane_id] = 0; @@ -103,9 +104,9 @@ inline __device__ void bitpack_mini_block( __syncwarp(); // Copy scratch data to final destination - auto available_bytes = (count * nbits + 7) / 8; + auto const available_bytes = util::div_rounding_up_safe(count * nbits, 8U); + auto const scratch_bytes = reinterpret_cast(scratch); - auto scratch_bytes = reinterpret_cast(scratch); for (uint32_t i = lane_id; i < available_bytes; i += warp_size) { dst[i] = scratch_bytes[i]; } @@ -254,7 +255,7 @@ class DeltaBinaryPacker { mb_ptr, norm_delta, num_enc, _mb_bits[warp_id], _bitpack_tmp); } - // last lane updates global delta ptr + // last warp updates global delta ptr if (warp_id == delta::num_mini_blocks - 1 && lane_id == 0) { _dst = mb_ptr + _mb_bits[warp_id] * delta::values_per_mini_block / 8; _current_idx = min(warp_idx + delta::values_per_mini_block, _num_values); diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index fe212ec6714..568e8bd83fb 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -75,6 +75,19 @@ constexpr uint32_t MAX_GRID_Y_SIZE = (1 << 16) - 1; // space needed for RLE length field constexpr int RLE_LENGTH_FIELD_LEN = 4; +// helpers to do bit operations on enums +template , bool> = true> +constexpr uint32_t operator&(Enum a, Enum b) +{ + return static_cast(a) & static_cast(b); +} + +template , bool> = true> +constexpr uint32_t operator&(uint32_t a, Enum b) +{ + return a & static_cast(b); +} + struct frag_init_state_s { parquet_column_device_view col; PageFragment frag; @@ -491,7 +504,7 @@ __global__ void __launch_bounds__(128) __syncwarp(); if (t == 0) { if (not pages.empty()) { - page_g.kernel_mask = ENC_MASK_PLAIN; + page_g.kernel_mask = EncodeKernelMask::PLAIN; pages[ck_g.first_page] = page_g; } if (not page_sizes.empty()) { page_sizes[ck_g.first_page] = page_g.max_data_size; } @@ -622,11 +635,11 @@ __global__ void __launch_bounds__(128) if (t == 0) { if (not pages.empty()) { if (is_use_delta) { - page_g.kernel_mask = ENC_MASK_DELTA_BINARY; + page_g.kernel_mask = EncodeKernelMask::DELTA_BINARY; } else if (ck_g.use_dictionary || physical_type == BOOLEAN) { - page_g.kernel_mask = ENC_MASK_DICTIONARY; + page_g.kernel_mask = EncodeKernelMask::DICTIONARY; } else { - page_g.kernel_mask = ENC_MASK_PLAIN; + page_g.kernel_mask = EncodeKernelMask::PLAIN; } pages[ck_g.first_page + num_pages] = page_g; } @@ -1104,7 +1117,7 @@ __device__ auto julian_days_with_time(int64_t v) // the level data is encoded. // FIXME: what should the args to launch_bounds be now? 
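// [Annotation, not part of the patch: context for the FIXME above.
//  __launch_bounds__(max_threads_per_block, min_blocks_per_sm) tells the
//  compiler to budget registers so that min_blocks_per_sm blocks of up to
//  max_threads_per_block threads can stay resident on one SM; e.g.
//
//    template <int block_size>
//    __global__ void __launch_bounds__(block_size, 8) some_kernel(...);
//
//  guarantees room for eight resident blocks. Splitting the encoders changes
//  each kernel's register footprint, hence the question of whether 8 is still
//  the right floor.]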
// blockDim(128, 1, 1) -template +template __global__ void __launch_bounds__(block_size, 8) gpuEncodePageLevels(device_span pages, bool write_v2_headers) { @@ -1334,7 +1347,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if ((s->page.kernel_mask & ENC_MASK_PLAIN) == 0) { return; } + if ((s->page.kernel_mask & EncodeKernelMask::PLAIN) == 0) { return; } // Encode data values __syncthreads(); @@ -1576,7 +1589,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if ((s->page.kernel_mask & ENC_MASK_DICTIONARY) == 0) { return; } + if ((s->page.kernel_mask & EncodeKernelMask::DICTIONARY) == 0) { return; } // Encode data values __syncthreads(); @@ -1715,7 +1728,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if ((s->page.kernel_mask & ENC_MASK_DELTA_BINARY) == 0) { return; } + if ((s->page.kernel_mask & EncodeKernelMask::DELTA_BINARY) == 0) { return; } // Encode data values __syncthreads(); @@ -2538,7 +2551,7 @@ constexpr __device__ void* align8(void* ptr) } struct mask_tform { - __device__ uint32_t operator()(EncPage const& p) { return p.kernel_mask; } + __device__ uint32_t operator()(EncPage const& p) { return static_cast(p.kernel_mask); } }; } // namespace @@ -2688,8 +2701,8 @@ void EncodePages(device_span pages, auto num_pages = pages.size(); // determine which kernels to invoke - auto mask_iter = thrust::make_transform_iterator(pages.begin(), mask_tform{}); - int kernel_mask = thrust::reduce( + auto mask_iter = thrust::make_transform_iterator(pages.begin(), mask_tform{}); + uint32_t kernel_mask = thrust::reduce( rmm::exec_policy(stream), mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); // get the number of streams we need from the pool @@ -2700,23 +2713,23 @@ void EncodePages(device_span pages, // deal with one datatype. int s_idx = 0; - if ((kernel_mask & ENC_MASK_PLAIN) != 0) { + if ((kernel_mask & EncodeKernelMask::PLAIN) != 0) { auto const strm = streams[s_idx++]; - gpuEncodePageLevels + gpuEncodePageLevels <<>>(pages, write_v2_headers); gpuEncodePages<<>>( pages, comp_in, comp_out, comp_results, write_v2_headers); } - if ((kernel_mask & ENC_MASK_DELTA_BINARY) != 0) { + if ((kernel_mask & EncodeKernelMask::DELTA_BINARY) != 0) { auto const strm = streams[s_idx++]; - gpuEncodePageLevels + gpuEncodePageLevels <<>>(pages, write_v2_headers); gpuEncodeDeltaBinaryPages <<>>(pages, comp_in, comp_out, comp_results); } - if ((kernel_mask & ENC_MASK_DICTIONARY) != 0) { + if ((kernel_mask & EncodeKernelMask::DICTIONARY) != 0) { auto const strm = streams[s_idx++]; - gpuEncodePageLevels + gpuEncodePageLevels <<>>(pages, write_v2_headers); gpuEncodeDictPages<<>>( pages, comp_in, comp_out, comp_results, write_v2_headers); diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index c2892ed6495..167bac24bc8 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -401,10 +401,10 @@ constexpr uint32_t encoding_to_mask(Encoding encoding) * * Used to control which encode kernels to run. 
*/ -enum encoder_kernel_mask_bits { - ENC_MASK_PLAIN = (1 << 0), // Run plain encoding kernel - ENC_MASK_DICTIONARY = (1 << 1), // Run dictionary encoding kernel - ENC_MASK_DELTA_BINARY = (1 << 2) // Run DELTA_BINARY_PACKED encoding kernel +enum class EncodeKernelMask { + PLAIN = (1 << 0), // Run plain encoding kernel + DICTIONARY = (1 << 1), // Run dictionary encoding kernel + DELTA_BINARY = (1 << 2) // Run DELTA_BINARY_PACKED encoding kernel }; /** @@ -469,7 +469,7 @@ struct EncPage { uint32_t rep_lvl_bytes; //!< Number of bytes of encoded repetition level data (V2 only) compression_result* comp_res; //!< Ptr to compression result uint32_t num_nulls; //!< Number of null values (V2 only) (down here for alignment) - uint32_t kernel_mask; //!< Mask used to control which encoding kernels to run + EncodeKernelMask kernel_mask; //!< Mask used to control which encoding kernels to run }; /** From cd4df51659167f1f5ece6dec9c8b3db138bb61eb Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 13 Sep 2023 10:54:56 -0700 Subject: [PATCH 03/37] more cleanup --- cpp/src/io/parquet/page_enc.cu | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 568e8bd83fb..78d79b644d8 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -266,11 +266,11 @@ struct BitwiseOr { // F is a function that computes validity and the src index for a given input position template struct delta_enc { - page_enc_state_s<0>* s; - uint32_t valid_count; + page_enc_state_s<0>* const s; + uint32_t const valid_count; F& f; - uint64_t* buffer; - void* temp_space; + uint64_t* const buffer; + void* const temp_space; __device__ uint8_t const* encode() { @@ -281,7 +281,7 @@ struct delta_enc { if (t == 0) { packer.init(s->cur, valid_count, reinterpret_cast(buffer), temp_space); } __syncthreads(); - // FIXME int the plain encoder the scaling is a little different for INT32 than INT64. + // FIXME(ets): in the plain encoder the scaling is a little different for INT32 than INT64. // might need to patch this up some. int32_t const scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale; for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { @@ -1115,7 +1115,7 @@ __device__ auto julian_days_with_time(int64_t v) // this has been split out into its own kernel because of the amount of shared memory required // for the state buffer. encode kernels that don't use the RLE buffer can get started while // the level data is encoded. -// FIXME: what should the args to launch_bounds be now? +// FIXME(ets): what should the args to launch_bounds be now? // blockDim(128, 1, 1) template __global__ void __launch_bounds__(block_size, 8) @@ -1302,15 +1302,15 @@ __device__ void finish_page_encode(state_buf* s, // copy uncompressed bytes over if (skip_comp_size != 0 && not comp_in.empty()) { - uint8_t* src = s->page.page_data + s->page.max_hdr_size; - uint8_t* dst = s->page.compressed_data + s->page.max_hdr_size; + uint8_t* const src = s->page.page_data + s->page.max_hdr_size; + uint8_t* const dst = s->page.compressed_data + s->page.max_hdr_size; for (int i = t; i < skip_comp_size; i += block_size) { dst[i] = src[i]; } } } -// FIXME: what should the args to launch_bounds be now? +// FIXME(ets): what should the args to launch_bounds be now? 
// blockDim(128, 1, 1) template __global__ void __launch_bounds__(block_size, 8) @@ -1552,7 +1552,7 @@ __global__ void __launch_bounds__(block_size, 8) s, valid_count, s->cur, pages, comp_in, comp_out, comp_results, write_v2_headers); } -// FIXME: what should the args to launch_bounds be now? +// FIXME(ets): what should the args to launch_bounds be now? // blockDim(128, 1, 1) template __global__ void __launch_bounds__(block_size, 8) @@ -1790,7 +1790,6 @@ __global__ void __launch_bounds__(block_size, 8) uint8_t const* delta_ptr = nullptr; // this will be the end of delta block pointer if (physical_type == INT32) { - // FIXME need to handle all the time scaling stuff here too if (dtype_len_in == 4) { delta_enc encoder{ s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; From 2e33420a3d9d6805b7d1f2ebccb17d91b462841d Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 13 Sep 2023 11:09:44 -0700 Subject: [PATCH 04/37] remove some FIXMEs --- cpp/src/io/parquet/page_enc.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 78d79b644d8..38212c0ead2 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -1115,7 +1115,6 @@ __device__ auto julian_days_with_time(int64_t v) // this has been split out into its own kernel because of the amount of shared memory required // for the state buffer. encode kernels that don't use the RLE buffer can get started while // the level data is encoded. -// FIXME(ets): what should the args to launch_bounds be now? // blockDim(128, 1, 1) template __global__ void __launch_bounds__(block_size, 8) @@ -1310,7 +1309,7 @@ __device__ void finish_page_encode(state_buf* s, } } -// FIXME(ets): what should the args to launch_bounds be now? +// PLAIN page data encoder // blockDim(128, 1, 1) template __global__ void __launch_bounds__(block_size, 8) @@ -1552,7 +1551,7 @@ __global__ void __launch_bounds__(block_size, 8) s, valid_count, s->cur, pages, comp_in, comp_out, comp_results, write_v2_headers); } -// FIXME(ets): what should the args to launch_bounds be now? +// DICTIONARY page data encoder // blockDim(128, 1, 1) template __global__ void __launch_bounds__(block_size, 8) @@ -1692,7 +1691,7 @@ __global__ void __launch_bounds__(block_size, 8) s, valid_count, s->cur, pages, comp_in, comp_out, comp_results, write_v2_headers); } -// FIXME: what should the args to launch_bounds be now? +// DELTA_BINARY_PACKED page data encoder // blockDim(128, 1, 1) template __global__ void __launch_bounds__(block_size, 8) From 0bd64c851b275a0152b5028a5ab063b2ee359627 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 13 Sep 2023 11:16:59 -0700 Subject: [PATCH 05/37] change FIXME to TODO for TS scaling --- cpp/src/io/parquet/page_enc.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 38212c0ead2..570ed3d8314 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -281,8 +281,8 @@ struct delta_enc { if (t == 0) { packer.init(s->cur, valid_count, reinterpret_cast(buffer), temp_space); } __syncthreads(); - // FIXME(ets): in the plain encoder the scaling is a little different for INT32 than INT64. - // might need to patch this up some. + // TODO(ets): in the plain encoder the scaling is a little different for INT32 than INT64. + // might need to modify this if there's a big performance hit in the 32-bit case. int32_t const scale = s->col.ts_scale == 0 ? 
1 : s->col.ts_scale; for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size); From fffc659e05ccf57fe4985234adef8e34406332f1 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 13 Sep 2023 13:48:53 -0700 Subject: [PATCH 06/37] use switch rather than if/else block --- cpp/src/io/parquet/page_enc.cu | 42 +++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 570ed3d8314..74f1ac2260a 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -1789,22 +1789,32 @@ __global__ void __launch_bounds__(block_size, 8) uint8_t const* delta_ptr = nullptr; // this will be the end of delta block pointer if (physical_type == INT32) { - if (dtype_len_in == 4) { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; - delta_ptr = encoder.encode(); - } else if (dtype_len_in == 2) { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; - delta_ptr = encoder.encode(); - } else if (dtype_len_in == 8) { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; - delta_ptr = encoder.encode(); - } else { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; - delta_ptr = encoder.encode(); + switch (dtype_len_in) { + case 8: { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + break; + } + case 4: { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + break; + } + case 2: { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + break; + } + case 1: { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + break; + } + default: CUDF_UNREACHABLE("invalid dtype_len_in when encoding DELTA_BINARY_PACKED"); } } else { delta_enc encoder{ From a98b338a6dcfe211a826808f5ebe4d62bdeda16c Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 13 Sep 2023 14:25:29 -0700 Subject: [PATCH 07/37] get rid of some magic numbers --- cpp/src/io/parquet/page_enc.cu | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 74f1ac2260a..ac29d331a17 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -1200,10 +1200,12 @@ __global__ void __launch_bounds__(block_size, 8) __syncthreads(); } if (t < 32) { - uint8_t* const cur = s->cur; - uint8_t* const rle_out = s->rle_out; - uint32_t const rle_bytes = static_cast(rle_out - cur) - (is_v2 ? 0 : 4); - if (not is_v2 && t < 4) { cur[t] = rle_bytes >> (t * 8); } + uint8_t* const cur = s->cur; + uint8_t* const rle_out = s->rle_out; + // V2 does not write the RLE length field + uint32_t const rle_bytes = + static_cast(rle_out - cur) - (is_v2 ? 
0 : RLE_LENGTH_FIELD_LEN); + if (not is_v2 && t < RLE_LENGTH_FIELD_LEN) { cur[t] = rle_bytes >> (t * 8); } __syncwarp(); if (t == 0) { s->cur = rle_out; @@ -1242,10 +1244,12 @@ __global__ void __launch_bounds__(block_size, 8) __syncthreads(); } if (t < 32) { - uint8_t* const cur = s->cur; - uint8_t* const rle_out = s->rle_out; - uint32_t const rle_bytes = static_cast(rle_out - cur) - (is_v2 ? 0 : 4); - if (not is_v2 && t < 4) { cur[t] = rle_bytes >> (t * 8); } + uint8_t* const cur = s->cur; + uint8_t* const rle_out = s->rle_out; + // V2 does not write the RLE length field + uint32_t const rle_bytes = + static_cast(rle_out - cur) - (is_v2 ? 0 : RLE_LENGTH_FIELD_LEN); + if (not is_v2 && t < RLE_LENGTH_FIELD_LEN) { cur[t] = rle_bytes >> (t * 8); } __syncwarp(); if (t == 0) { s->cur = rle_out; @@ -1339,9 +1343,10 @@ __global__ void __launch_bounds__(block_size, 8) // get s->cur back to where it was at the end of encoding the rep and def level data s->cur = s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + // if V1 data page, need space for the RLE length fields if (s->page.page_type == PageType::DATA_PAGE) { - if (s->col.num_def_level_bits() != 0) { s->cur += 4; } - if (s->col.num_rep_level_bits() != 0) { s->cur += 4; } + if (s->col.num_def_level_bits() != 0) { s->cur += RLE_LENGTH_FIELD_LEN; } + if (s->col.num_rep_level_bits() != 0) { s->cur += RLE_LENGTH_FIELD_LEN; } } } __syncthreads(); @@ -1581,9 +1586,10 @@ __global__ void __launch_bounds__(block_size, 8) // get s->cur back to where it was at the end of encoding the rep and def level data s->cur = s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + // if V1 data page, need space for the RLE length fields if (s->page.page_type == PageType::DATA_PAGE) { - if (s->col.num_def_level_bits() != 0) { s->cur += 4; } - if (s->col.num_rep_level_bits() != 0) { s->cur += 4; } + if (s->col.num_def_level_bits() != 0) { s->cur += RLE_LENGTH_FIELD_LEN; } + if (s->col.num_rep_level_bits() != 0) { s->cur += RLE_LENGTH_FIELD_LEN; } } } __syncthreads(); From 9d1b2f232cc59921bc51f46b64d48bdf30d92d1b Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 13 Sep 2023 16:50:17 -0700 Subject: [PATCH 08/37] replace operator overload with templated function --- cpp/src/io/parquet/page_enc.cu | 27 +++++++-------------------- cpp/src/io/parquet/parquet_gpu.hpp | 20 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index ac29d331a17..fad6dac5730 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -75,19 +75,6 @@ constexpr uint32_t MAX_GRID_Y_SIZE = (1 << 16) - 1; // space needed for RLE length field constexpr int RLE_LENGTH_FIELD_LEN = 4; -// helpers to do bit operations on enums -template , bool> = true> -constexpr uint32_t operator&(Enum a, Enum b) -{ - return static_cast(a) & static_cast(b); -} - -template , bool> = true> -constexpr uint32_t operator&(uint32_t a, Enum b) -{ - return a & static_cast(b); -} - struct frag_init_state_s { parquet_column_device_view col; PageFragment frag; @@ -1139,7 +1126,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if ((s->page.kernel_mask & kernel_mask) == 0) { return; } + if (BitAnd(s->page.kernel_mask, kernel_mask) == 0) { return; } auto const is_v2 = s->page.page_type == PageType::DATA_PAGE_V2; @@ -1351,7 +1338,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if 
((s->page.kernel_mask & EncodeKernelMask::PLAIN) == 0) { return; } + if (BitAnd(s->page.kernel_mask, EncodeKernelMask::PLAIN) == 0) { return; } // Encode data values __syncthreads(); @@ -1594,7 +1581,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if ((s->page.kernel_mask & EncodeKernelMask::DICTIONARY) == 0) { return; } + if (BitAnd(s->page.kernel_mask, EncodeKernelMask::DICTIONARY) == 0) { return; } // Encode data values __syncthreads(); @@ -1733,7 +1720,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if ((s->page.kernel_mask & EncodeKernelMask::DELTA_BINARY) == 0) { return; } + if (BitAnd(s->page.kernel_mask, EncodeKernelMask::DELTA_BINARY) == 0) { return; } // Encode data values __syncthreads(); @@ -2727,21 +2714,21 @@ void EncodePages(device_span pages, // deal with one datatype. int s_idx = 0; - if ((kernel_mask & EncodeKernelMask::PLAIN) != 0) { + if (BitAnd(kernel_mask, EncodeKernelMask::PLAIN) != 0) { auto const strm = streams[s_idx++]; gpuEncodePageLevels <<>>(pages, write_v2_headers); gpuEncodePages<<>>( pages, comp_in, comp_out, comp_results, write_v2_headers); } - if ((kernel_mask & EncodeKernelMask::DELTA_BINARY) != 0) { + if (BitAnd(kernel_mask, EncodeKernelMask::DELTA_BINARY) != 0) { auto const strm = streams[s_idx++]; gpuEncodePageLevels <<>>(pages, write_v2_headers); gpuEncodeDeltaBinaryPages <<>>(pages, comp_in, comp_out, comp_results); } - if ((kernel_mask & EncodeKernelMask::DICTIONARY) != 0) { + if (BitAnd(kernel_mask, EncodeKernelMask::DICTIONARY) != 0) { auto const strm = streams[s_idx++]; gpuEncodePageLevels <<>>(pages, write_v2_headers); diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 167bac24bc8..a3daab9d442 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -81,6 +81,26 @@ namespace gpu { using uleb128_t = uint64_t; using zigzag128_t = int64_t; +// TODO this is in C++23 +template +struct is_scoped_enum { + static const bool value = + std::is_enum_v and not std::is_convertible_v>; +}; + +// helpers to do bit operations on scoped enums +template ::value, bool> = true> +constexpr uint32_t BitAnd(Enum a, Enum b) +{ + return static_cast(a) & static_cast(b); +} + +template ::value, bool> = true> +constexpr uint32_t BitAnd(uint32_t a, Enum b) +{ + return a & static_cast(b); +} + /** * @brief Enums for the flags in the page header */ From ea6f3c0f232051f39c72a98b679c20371a0395da Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 15 Sep 2023 11:23:08 -0700 Subject: [PATCH 09/37] redo typing of the delta bit packer --- cpp/src/io/parquet/delta_enc.cuh | 50 +++++++++-------- cpp/src/io/parquet/page_enc.cu | 94 +++++++++++++++++++------------- 2 files changed, 83 insertions(+), 61 deletions(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index 471774b2561..dab42227827 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -56,16 +56,18 @@ using index_scan = cub::BlockScan; constexpr int rolling_idx(int index) { return rolling_index(index); } // version of bit packer that can handle up to 64 bits values. -template +// T is the type to use for processing. if nbits <= 32 use uint32_t, otherwise unsigned long long +// (not uint64_t because of atomicOr's typing). allowing this to be selectable since there's a +// measurable impact to using the wider types. 
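// [Annotation, not part of the patch: the template header on the next line
//  lost its contents in extraction; given the comment above it presumably
//  reads something like
//
//    template <typename scratch_type>
//    inline __device__ void bitpack_mini_block(
//      uint8_t* dst, uleb128_t val, uint32_t count, uint8_t nbits, void* temp_space)
//
//  with scratch_type chosen as uint32_t or unsigned long long by the caller,
//  and wide_type below being twice that width (uint64_t or __uint128_t) so a
//  shifted value can straddle two scratch words.]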
+template inline __device__ void bitpack_mini_block( - uint8_t* dst, T val, uint32_t count, uint8_t nbits, void* temp_space) + uint8_t* dst, uleb128_t val, uint32_t count, uint8_t nbits, void* temp_space) { - // typing for atomicOr is annoying. uint64_t doesn't work, need unsigned long long. - using scratch_type = - std::conditional_t, unsigned long long, uint32_t>; + using wide_type = + std::conditional_t, __uint128_t, uint64_t>; using cudf::detail::warp_size; - T constexpr mask = sizeof(T) * 8 - 1; - auto constexpr div = sizeof(T) * 8; + scratch_type constexpr mask = sizeof(scratch_type) * 8 - 1; + auto constexpr div = sizeof(scratch_type) * 8; auto const lane_id = threadIdx.x % warp_size; auto const warp_id = threadIdx.x / warp_size; @@ -76,12 +78,13 @@ inline __device__ void bitpack_mini_block( scratch[lane_id] = 0; __syncwarp(); - // why use bit packing when there's no savings??? + // TODO: see if there is any savings using special packing for easy bitwidths (1,2,4,8,16...) + // like what's done for the RLE encoder. if (nbits == div) { if (lane_id < count) { - for (int i = 0; i < sizeof(T); i++) { - dst[lane_id * sizeof(T) + i] = val & 0xff; - if constexpr (sizeof(T) > 1) { val >>= 8; } + for (int i = 0; i < sizeof(scratch_type); i++) { + dst[lane_id * sizeof(scratch_type) + i] = val & 0xff; + val >>= 8; } } __syncwarp(); @@ -90,12 +93,12 @@ inline __device__ void bitpack_mini_block( if (lane_id <= count) { // shift symbol left by up to mask bits - WideType v2 = val; + wide_type v2 = val; v2 <<= (lane_id * nbits) & mask; // Copy N bit word into two N/2 bit words while following C++ strict aliasing rules. - T v1[2]; - memcpy(&v1, &v2, sizeof(WideType)); + scratch_type v1[2]; + memcpy(&v1, &v2, sizeof(wide_type)); // Atomically write result to scratch if (v1[0]) { atomicOr(scratch + ((lane_id * nbits) / div), v1[0]); } @@ -118,14 +121,10 @@ inline __device__ void bitpack_mini_block( // Object used to turn a stream of integers into a DELTA_BINARY_PACKED stream. This takes as input // 128 values with validity at a time, saving them until there are enough values for a block // to be written. 
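// [Annotation, not part of the patch: for orientation, the byte stream this
//  packer emits follows the Parquet DELTA_BINARY_PACKED layout:
//
//    <block size: ULEB128> <mini-blocks per block: ULEB128>
//    <total value count: ULEB128> <first value: zigzag ULEB128>
//    then, per block:
//      <min delta: zigzag ULEB128> <one bit-width byte per mini-block>
//      <packed deltas, mini-block 0> ... <packed deltas, mini-block 3>
//
//  with the constants chosen here: 128 values per block in 4 mini-blocks of
//  32 values.]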
-// -// T can only be uint32_t or uint64_t since the DELTA_BINARY_PACKED encoding is only defined for -// INT32 and INT64 physical types -template +// T is the input data type (either zigzag128_t or uleb128_t) +template class DeltaBinaryPacker { private: - // static_assert(std::is_same_v || std::is_same_v); - uint8_t* _dst; // sink to dump encoded values to size_type _current_idx; // index of first value in buffer uint32_t _num_values; // total number of values to encode @@ -231,6 +230,7 @@ class DeltaBinaryPacker { // for the bitpacking of this warp zigzag128_t const warp_max = delta::warp_reduce(_warp_tmp[warp_id]).Reduce(norm_delta, cub::Max()); + __syncthreads(); if (lane_id == 0) { _mb_bits[warp_id] = sizeof(zigzag128_t) * 8 - __clzll(warp_max); } __syncthreads(); @@ -251,9 +251,15 @@ class DeltaBinaryPacker { auto const warp_idx = _current_idx + warp_id * delta::values_per_mini_block; if (warp_idx < _num_values) { auto const num_enc = min(delta::values_per_mini_block, _num_values - warp_idx); - delta::bitpack_mini_block( - mb_ptr, norm_delta, num_enc, _mb_bits[warp_id], _bitpack_tmp); + if (_mb_bits[warp_id] > 32) { + delta::bitpack_mini_block( + mb_ptr, norm_delta, num_enc, _mb_bits[warp_id], _bitpack_tmp); + } else { + delta::bitpack_mini_block( + mb_ptr, norm_delta, num_enc, _mb_bits[warp_id], _bitpack_tmp); + } } + __syncthreads(); // last warp updates global delta ptr if (warp_id == delta::num_mini_blocks - 1 && lane_id == 0) { diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index fad6dac5730..0e76349fd94 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -247,25 +247,38 @@ struct BitwiseOr { } }; -// T is the parquet physical type -// W is double the bitwidth of T // I is the column type from the input table -// F is a function that computes validity and the src index for a given input position -template +template struct delta_enc { + using output_type = std::conditional_t, zigzag128_t, uleb128_t>; + page_enc_state_s<0>* const s; uint32_t const valid_count; - F& f; uint64_t* const buffer; void* const temp_space; + __device__ thrust::pair calc_idx_and_validity(uint32_t cur_val_idx) + { + size_type const val_idx_in_block = cur_val_idx + threadIdx.x; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + + uint32_t const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && + val_idx_in_block < s->page.num_leaf_values) + ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) + : 0; + + return {is_valid, val_idx_in_leaf_col}; + } + __device__ uint8_t const* encode() { - __shared__ DeltaBinaryPacker packer; + __shared__ DeltaBinaryPacker packer; auto const t = threadIdx.x; - if (t == 0) { packer.init(s->cur, valid_count, reinterpret_cast(buffer), temp_space); } + if (t == 0) { + packer.init(s->cur, valid_count, reinterpret_cast(buffer), temp_space); + } __syncthreads(); // TODO(ets): in the plain encoder the scaling is a little different for INT32 than INT64. 
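// [Annotation, not part of the patch: the ts_scale convention the TODO refers
//  to, as applied in the encode loop below -- 0 means no scaling, a positive
//  scale multiplies (e.g. seconds stored in a millisecond column: v * 1000),
//  and a negative scale divides (e.g. nanoseconds written out as
//  microseconds: v / 1000):
//
//    int32_t const scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale;
//    ...
//    if (scale < 0) { v /= -scale; } else { v *= scale; }
//  ]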
@@ -274,10 +287,10 @@ struct delta_enc { for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size); - auto [is_valid, val_idx] = f(cur_val_idx); + auto [is_valid, val_idx] = calc_idx_and_validity(cur_val_idx); cur_val_idx += nvals; - T v = s->col.leaf_column->element(val_idx); + output_type v = s->col.leaf_column->element(val_idx); if (scale < 0) { v /= -scale; } else { @@ -1424,8 +1437,8 @@ __global__ void __launch_bounds__(block_size, 8) switch (dtype_len) { case 8: return col->element(idx) * scale; case 4: return col->element(idx) * scale; - case 2: return col->element(idx) * scale; - default: return col->element(idx) * scale; + case 2: return (col->element(idx) * scale) & 0xffff; + default: return (col->element(idx) * scale) & 0xff; } }(); @@ -1694,8 +1707,7 @@ __global__ void __launch_bounds__(block_size, 8) device_span comp_results) { // block of shared memory for value storage and bit packing - // TODO add constant that's the sum of buffer_size and block_size - __shared__ uint64_t delta_shared[delta::buffer_size + delta::block_size]; + __shared__ uleb128_t delta_shared[delta::buffer_size + delta::block_size]; __shared__ __align__(8) page_enc_state_s<0> state_g; using block_reduce = cub::BlockReduce; __shared__ union { @@ -1767,52 +1779,56 @@ __global__ void __launch_bounds__(block_size, 8) valid_count = block_reduce(temp_storage.reduce_storage).Sum(num_valid); } - auto calc_idx_and_validity = [&](uint32_t cur_val_idx) { - size_type const val_idx_in_block = cur_val_idx + t; - size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; - - uint32_t const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && - val_idx_in_block < s->page.num_leaf_values) - ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) - : 0; - - return std::make_tuple(is_valid, val_idx_in_leaf_col); - }; - uint8_t const* delta_ptr = nullptr; // this will be the end of delta block pointer if (physical_type == INT32) { switch (dtype_len_in) { case 8: { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + // only DURATIONS map to 8 bytes, so safe to just use signed here? 
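// [Annotation, not part of the patch: the delta_enc template arguments in the
//  cases below were lost in extraction. From the surrounding type dispatch
//  they are presumably the matching fixed-width types, e.g. for this case
//
//    delta_enc<int64_t> encoder{s, valid_count, delta_shared, &temp_storage};
//
//  and delta_enc<uint32_t> / delta_enc<int32_t>, delta_enc<uint16_t> /
//  delta_enc<int16_t>, delta_enc<uint8_t> / delta_enc<int8_t> for the
//  narrower cases that follow.]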
+ delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; delta_ptr = encoder.encode(); break; } case 4: { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; - delta_ptr = encoder.encode(); + if (type_id == type_id::UINT32) { + delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else { + delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } break; } case 2: { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; - delta_ptr = encoder.encode(); + if (type_id == type_id::UINT16) { + delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else { + delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } break; } case 1: { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; - delta_ptr = encoder.encode(); + if (type_id == type_id::UINT8) { + delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else { + delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } break; } default: CUDF_UNREACHABLE("invalid dtype_len_in when encoding DELTA_BINARY_PACKED"); } } else { - delta_enc encoder{ - s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; - delta_ptr = encoder.encode(); + if (type_id == type_id::UINT64) { + delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else { + delta_enc encoder{s, valid_count, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } } finish_page_encode( From 2d41d3abb717457ead53c1e8ef463bd979334cc1 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 15 Sep 2023 13:50:46 -0700 Subject: [PATCH 10/37] add const --- cpp/src/io/parquet/delta_enc.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index dab42227827..731955c43ab 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -108,7 +108,7 @@ inline __device__ void bitpack_mini_block( // Copy scratch data to final destination auto const available_bytes = util::div_rounding_up_safe(count * nbits, 8U); - auto const scratch_bytes = reinterpret_cast(scratch); + auto const scratch_bytes = reinterpret_cast(scratch); for (uint32_t i = lane_id; i < available_bytes; i += warp_size) { dst[i] = scratch_bytes[i]; From 1d946217ba9401df2e7726dd87e69d541a01636f Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 15 Sep 2023 13:51:27 -0700 Subject: [PATCH 11/37] add test for delta binary writer --- cpp/tests/io/parquet_test.cpp | 50 ++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index 907805d0abd..c381845bd20 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -348,6 +348,9 @@ struct ParquetWriterSchemaTest : public ParquetWriterTest { template struct ParquetReaderSourceTest : public ParquetReaderTest {}; +template +struct ParquetWriterDeltaTest : public ParquetWriterTest {}; + // Declare typed test cases // TODO: Replace with `NumericTypes` when unsigned support is added. 
Issue #5352 using SupportedTypes = cudf::test::Types; @@ -362,6 +365,10 @@ TYPED_TEST_SUITE(ParquetWriterTimestampTypeTest, SupportedTimestampTypes); TYPED_TEST_SUITE(ParquetWriterSchemaTest, cudf::test::AllTypes); using ByteLikeTypes = cudf::test::Types; TYPED_TEST_SUITE(ParquetReaderSourceTest, ByteLikeTypes); +using DeltaDecimalTypes = cudf::test::Types; +using DeltaBinaryTypes = + cudf::test::Concat; +TYPED_TEST_SUITE(ParquetWriterDeltaTest, DeltaBinaryTypes); // Base test fixture for chunked writer tests struct ParquetChunkedWriterTest : public cudf::test::BaseFixture {}; @@ -379,7 +386,6 @@ TYPED_TEST_SUITE(ParquetChunkedWriterNumericTypeTest, SupportedTypes); class ParquetSizedTest : public ::cudf::test::BaseFixtureWithParam {}; // test the allowed bit widths for dictionary encoding -// values chosen to trigger 1, 2, 3, 4, 5, 6, 8, 10, 12, 16, 20, and 24 bit dictionaries INSTANTIATE_TEST_SUITE_P(ParquetDictionaryTest, ParquetSizedTest, testing::Range(1, 25), @@ -540,9 +546,7 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, Timestamps) auto filepath = temp_env->get_temp_filepath("Timestamps.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) - .write_v2_headers(true) - .dictionary_policy(cudf::io::dictionary_policy::NEVER); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -568,9 +572,7 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, TimestampsWithNulls) auto filepath = temp_env->get_temp_filepath("TimestampsWithNulls.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) - .write_v2_headers(true) - .dictionary_policy(cudf::io::dictionary_policy::NEVER); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -594,9 +596,7 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, TimestampOverflow) auto filepath = temp_env->get_temp_filepath("ParquetTimestampOverflow.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) - .write_v2_headers(true) - .dictionary_policy(cudf::io::dictionary_policy::NEVER); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -6738,4 +6738,34 @@ TEST_P(ParquetV2Test, CheckEncodings) } } +TYPED_TEST(ParquetWriterDeltaTest, WriteDeltaBinaryPacked) +{ + using T = TypeParam; + auto col0 = testdata::ascending(); + auto col1 = testdata::descending(); + + auto const expected = table_view{{col0, col1}}; + + auto const filepath = temp_env->get_temp_filepath("DeltaBinaryPacked.parquet"); + cudf::io::parquet_writer_options out_opts = + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) + .write_v2_headers(true) + .dictionary_policy(cudf::io::dictionary_policy::NEVER); + cudf::io::write_parquet(out_opts); + + // FIXME: these three types fail whether delta encoding is used or not. Is this a problem + // with the test or is there something wrong in libcudf when writing these types. 
All three + // of them use ts_scale > 1 + bool constexpr is_failing = std::is_same_v or + std::is_same_v or + std::is_same_v; + + if constexpr (not is_failing) { + cudf::io::parquet_reader_options in_opts = + cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}); + auto result = cudf::io::read_parquet(in_opts); + CUDF_TEST_EXPECT_TABLES_EQUAL(expected, result.tbl->view()); + } +} + CUDF_TEST_PROGRAM_MAIN() From 4e2c4a5990fb1fbcc4a079f6ca4baacb15c64d18 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 15 Sep 2023 14:15:56 -0700 Subject: [PATCH 12/37] remove unsupported types from delta test --- cpp/tests/io/parquet_test.cpp | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index c381845bd20..aae7630de2f 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -365,10 +365,6 @@ TYPED_TEST_SUITE(ParquetWriterTimestampTypeTest, SupportedTimestampTypes); TYPED_TEST_SUITE(ParquetWriterSchemaTest, cudf::test::AllTypes); using ByteLikeTypes = cudf::test::Types; TYPED_TEST_SUITE(ParquetReaderSourceTest, ByteLikeTypes); -using DeltaDecimalTypes = cudf::test::Types; -using DeltaBinaryTypes = - cudf::test::Concat; -TYPED_TEST_SUITE(ParquetWriterDeltaTest, DeltaBinaryTypes); // Base test fixture for chunked writer tests struct ParquetChunkedWriterTest : public cudf::test::BaseFixture {}; @@ -6738,7 +6734,16 @@ TEST_P(ParquetV2Test, CheckEncodings) } } -TYPED_TEST(ParquetWriterDeltaTest, WriteDeltaBinaryPacked) +// removing duration_D, duration_s, and timestamp_s as they don't appear to be supported properly. +// see definition of UnsupportedChronoTypes above. +using DeltaDecimalTypes = cudf::test::Types; +using DeltaBinaryTypes = + cudf::test::Concat; +using SupportedDeltaTestTypes = + cudf::test::RemoveIf, DeltaBinaryTypes>; +TYPED_TEST_SUITE(ParquetWriterDeltaTest, SupportedDeltaTestTypes); + +TYPED_TEST(ParquetWriterDeltaTest, SupportedDeltaTestTypes) { using T = TypeParam; auto col0 = testdata::ascending(); @@ -6753,19 +6758,10 @@ TYPED_TEST(ParquetWriterDeltaTest, WriteDeltaBinaryPacked) .dictionary_policy(cudf::io::dictionary_policy::NEVER); cudf::io::write_parquet(out_opts); - // FIXME: these three types fail whether delta encoding is used or not. Is this a problem - // with the test or is there something wrong in libcudf when writing these types. 
All three - // of them use ts_scale > 1 - bool constexpr is_failing = std::is_same_v or - std::is_same_v or - std::is_same_v; - - if constexpr (not is_failing) { - cudf::io::parquet_reader_options in_opts = - cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}); - auto result = cudf::io::read_parquet(in_opts); - CUDF_TEST_EXPECT_TABLES_EQUAL(expected, result.tbl->view()); - } + cudf::io::parquet_reader_options in_opts = + cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}); + auto result = cudf::io::read_parquet(in_opts); + CUDF_TEST_EXPECT_TABLES_EQUAL(expected, result.tbl->view()); } CUDF_TEST_PROGRAM_MAIN() From b3b25dea04890dd1791f79c71ae9a6c5e415c600 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 15 Sep 2023 14:20:11 -0700 Subject: [PATCH 13/37] make second column unordered --- cpp/tests/io/parquet_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index aae7630de2f..2a5481de833 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -6747,7 +6747,7 @@ TYPED_TEST(ParquetWriterDeltaTest, SupportedDeltaTestTypes) { using T = TypeParam; auto col0 = testdata::ascending(); - auto col1 = testdata::descending(); + auto col1 = testdata::unordered(); auto const expected = table_view{{col0, col1}}; From 00b248fb4feacdadb4f953b5f29952e89b6428f6 Mon Sep 17 00:00:00 2001 From: seidl Date: Mon, 18 Sep 2023 09:37:24 -0700 Subject: [PATCH 14/37] remove leftover experiment --- cpp/src/io/parquet/page_enc.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 0e76349fd94..0ac62c9f056 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -1437,8 +1437,8 @@ __global__ void __launch_bounds__(block_size, 8) switch (dtype_len) { case 8: return col->element(idx) * scale; case 4: return col->element(idx) * scale; - case 2: return (col->element(idx) * scale) & 0xffff; - default: return (col->element(idx) * scale) & 0xff; + case 2: return col->element(idx) * scale; + default: return col->element(idx) * scale; } }(); From 8b95a786dd65c46b2fa747658f95a7830bc4fbdf Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 21 Sep 2023 08:26:10 -0700 Subject: [PATCH 15/37] change put_zz128 to void --- cpp/src/io/parquet/delta_enc.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index 731955c43ab..3a9dc1dd295 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -36,7 +36,7 @@ inline __device__ void put_uleb128(uint8_t*& p, uleb128_t v) *p++ = v; } -inline __device__ uint8_t* put_zz128(uint8_t*& p, zigzag128_t v) +inline __device__ void put_zz128(uint8_t*& p, zigzag128_t v) { zigzag128_t s = (v < 0); put_uleb128(p, (v ^ -s) * 2 + s); From 01502366f47403f26c9f5b12c0595930d4ad712c Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Thu, 28 Sep 2023 17:10:37 -0700 Subject: [PATCH 16/37] fix typo Co-authored-by: Vukasin Milovanovic --- cpp/src/io/parquet/delta_enc.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index 3a9dc1dd295..b4b93c04868 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -173,7 +173,7 @@ class DeltaBinaryPacker { _values_in_buffer = 0; } - // each thread calls this to add it's current value + // each 
thread calls this to add its current value inline __device__ void add_value(T value, bool is_valid) { // figure out the correct position for the given value From 64775434881e9cfc056b02815391d8273f9fcab2 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Thu, 28 Sep 2023 17:11:56 -0700 Subject: [PATCH 17/37] another typo Co-authored-by: Vukasin Milovanovic --- cpp/src/io/parquet/delta_enc.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index b4b93c04868..d605976ddba 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -239,7 +239,7 @@ class DeltaBinaryPacker { if (t == 0) { write_block_header(block_min); } __syncthreads(); - // now each warp encodes it's data...can calculate starting offset with _mb_bits + // now each warp encodes its data...can calculate starting offset with _mb_bits uint8_t* mb_ptr = _dst; switch (warp_id) { case 3: mb_ptr += _mb_bits[2] * delta::values_per_mini_block / 8; [[fallthrough]]; From 60df88e9536bf9a258583835369bb0df476d1f58 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 28 Sep 2023 18:11:43 -0700 Subject: [PATCH 18/37] implement suggestions from review --- cpp/src/io/parquet/delta_enc.cuh | 12 ++++++------ cpp/src/io/parquet/page_enc.cu | 30 +++++++++++++++--------------- cpp/src/io/parquet/parquet_gpu.hpp | 12 ++++++------ 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index d605976ddba..033376e345b 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -123,7 +123,7 @@ inline __device__ void bitpack_mini_block( // to be written. // T is the input data type (either zigzag128_t or uleb128_t) template -class DeltaBinaryPacker { +class delta_binary_packer { private: uint8_t* _dst; // sink to dump encoded values to size_type _current_idx; // index of first value in buffer @@ -140,12 +140,12 @@ class DeltaBinaryPacker { void* _bitpack_tmp; // pointer to shared scratch memory used in bitpacking // write the delta binary header. only call from thread 0 - inline __device__ void write_header(T first_value) + inline __device__ void write_header() { delta::put_uleb128(_dst, delta::block_size); delta::put_uleb128(_dst, delta::num_mini_blocks); delta::put_uleb128(_dst, _num_values); - delta::put_zz128(_dst, first_value); + delta::put_zz128(_dst, _buffer[0]); } // write the block header. 
only call from thread 0 @@ -189,7 +189,7 @@ class DeltaBinaryPacker { _values_in_buffer += num_valid; // if first pass write header if (_current_idx == 0) { - write_header(_buffer[0]); + write_header(); _current_idx = 1; _values_in_buffer -= 1; } @@ -230,7 +230,7 @@ class DeltaBinaryPacker { // for the bitpacking of this warp zigzag128_t const warp_max = delta::warp_reduce(_warp_tmp[warp_id]).Reduce(norm_delta, cub::Max()); - __syncthreads(); + __syncwarp(); if (lane_id == 0) { _mb_bits[warp_id] = sizeof(zigzag128_t) * 8 - __clzll(warp_max); } __syncthreads(); @@ -247,7 +247,7 @@ class DeltaBinaryPacker { case 1: mb_ptr += _mb_bits[0] * delta::values_per_mini_block / 8; } - // encoding happens here....will have to update pack literals to deal with larger numbers + // encoding happens here auto const warp_idx = _current_idx + warp_id * delta::values_per_mini_block; if (warp_idx < _num_values) { auto const num_enc = min(delta::values_per_mini_block, _num_values - warp_idx); diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 58236e4443f..ece92bea44d 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -272,7 +272,7 @@ struct delta_enc { __device__ uint8_t const* encode() { - __shared__ DeltaBinaryPacker packer; + __shared__ delta_binary_packer packer; auto const t = threadIdx.x; @@ -504,7 +504,7 @@ __global__ void __launch_bounds__(128) __syncwarp(); if (t == 0) { if (not pages.empty()) { - page_g.kernel_mask = EncodeKernelMask::PLAIN; + page_g.kernel_mask = encode_kernel_mask::PLAIN; pages[ck_g.first_page] = page_g; } if (not page_sizes.empty()) { page_sizes[ck_g.first_page] = page_g.max_data_size; } @@ -635,11 +635,11 @@ __global__ void __launch_bounds__(128) if (t == 0) { if (not pages.empty()) { if (is_use_delta) { - page_g.kernel_mask = EncodeKernelMask::DELTA_BINARY; + page_g.kernel_mask = encode_kernel_mask::DELTA_BINARY; } else if (ck_g.use_dictionary || physical_type == BOOLEAN) { - page_g.kernel_mask = EncodeKernelMask::DICTIONARY; + page_g.kernel_mask = encode_kernel_mask::DICTIONARY; } else { - page_g.kernel_mask = EncodeKernelMask::PLAIN; + page_g.kernel_mask = encode_kernel_mask::PLAIN; } pages[ck_g.first_page + num_pages] = page_g; } @@ -1116,7 +1116,7 @@ __device__ auto julian_days_with_time(int64_t v) // for the state buffer. encode kernels that don't use the RLE buffer can get started while // the level data is encoded. 
// blockDim(128, 1, 1) -template +template __global__ void __launch_bounds__(block_size, 8) gpuEncodePageLevels(device_span pages, bool write_v2_headers) { @@ -1351,7 +1351,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if (BitAnd(s->page.kernel_mask, EncodeKernelMask::PLAIN) == 0) { return; } + if (BitAnd(s->page.kernel_mask, encode_kernel_mask::PLAIN) == 0) { return; } // Encode data values __syncthreads(); @@ -1594,7 +1594,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if (BitAnd(s->page.kernel_mask, EncodeKernelMask::DICTIONARY) == 0) { return; } + if (BitAnd(s->page.kernel_mask, encode_kernel_mask::DICTIONARY) == 0) { return; } // Encode data values __syncthreads(); @@ -1732,7 +1732,7 @@ __global__ void __launch_bounds__(block_size, 8) } __syncthreads(); - if (BitAnd(s->page.kernel_mask, EncodeKernelMask::DELTA_BINARY) == 0) { return; } + if (BitAnd(s->page.kernel_mask, encode_kernel_mask::DELTA_BINARY) == 0) { return; } // Encode data values __syncthreads(); @@ -2730,23 +2730,23 @@ void EncodePages(device_span pages, // deal with one datatype. int s_idx = 0; - if (BitAnd(kernel_mask, EncodeKernelMask::PLAIN) != 0) { + if (BitAnd(kernel_mask, encode_kernel_mask::PLAIN) != 0) { auto const strm = streams[s_idx++]; - gpuEncodePageLevels + gpuEncodePageLevels <<>>(pages, write_v2_headers); gpuEncodePages<<>>( pages, comp_in, comp_out, comp_results, write_v2_headers); } - if (BitAnd(kernel_mask, EncodeKernelMask::DELTA_BINARY) != 0) { + if (BitAnd(kernel_mask, encode_kernel_mask::DELTA_BINARY) != 0) { auto const strm = streams[s_idx++]; - gpuEncodePageLevels + gpuEncodePageLevels <<>>(pages, write_v2_headers); gpuEncodeDeltaBinaryPages <<>>(pages, comp_in, comp_out, comp_results); } - if (BitAnd(kernel_mask, EncodeKernelMask::DICTIONARY) != 0) { + if (BitAnd(kernel_mask, encode_kernel_mask::DICTIONARY) != 0) { auto const strm = streams[s_idx++]; - gpuEncodePageLevels + gpuEncodePageLevels <<>>(pages, write_v2_headers); gpuEncodeDictPages<<>>( pages, comp_in, comp_out, comp_results, write_v2_headers); diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 1645ce87a31..f521c021f04 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -477,7 +477,7 @@ constexpr uint32_t encoding_to_mask(Encoding encoding) * * Used to control which encode kernels to run. */ -enum class EncodeKernelMask { +enum class encode_kernel_mask { PLAIN = (1 << 0), // Run plain encoding kernel DICTIONARY = (1 << 1), // Run dictionary encoding kernel DELTA_BINARY = (1 << 2) // Run DELTA_BINARY_PACKED encoding kernel @@ -541,11 +541,11 @@ struct EncPage { uint32_t num_leaf_values; //!< Values in page. Different from num_rows in case of nested types uint32_t num_values; //!< Number of def/rep level values in page. 
Includes null/empty elements in //!< non-leaf levels - uint32_t def_lvl_bytes; //!< Number of bytes of encoded definition level data (V2 only) - uint32_t rep_lvl_bytes; //!< Number of bytes of encoded repetition level data (V2 only) - compression_result* comp_res; //!< Ptr to compression result - uint32_t num_nulls; //!< Number of null values (V2 only) (down here for alignment) - EncodeKernelMask kernel_mask; //!< Mask used to control which encoding kernels to run + uint32_t def_lvl_bytes; //!< Number of bytes of encoded definition level data (V2 only) + uint32_t rep_lvl_bytes; //!< Number of bytes of encoded repetition level data (V2 only) + compression_result* comp_res; //!< Ptr to compression result + uint32_t num_nulls; //!< Number of null values (V2 only) (down here for alignment) + encode_kernel_mask kernel_mask; //!< Mask used to control which encoding kernels to run }; /** From e6cc71e90bf24effbe25419ef48f29e400f78e38 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 29 Sep 2023 11:39:51 -0700 Subject: [PATCH 19/37] use template foo to get a single BitAnd --- cpp/src/io/parquet/parquet_gpu.hpp | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index f521c021f04..fefc0db9d70 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -97,25 +97,28 @@ using uleb128_t = uint64_t; using zigzag128_t = int64_t; // TODO this is in C++23 -template +template > struct is_scoped_enum { - static const bool value = - std::is_enum_v and not std::is_convertible_v>; + static const bool value = not std::is_convertible_v>; +}; + +template +struct is_scoped_enum { + static const bool value = false; }; // helpers to do bit operations on scoped enums -template ::value, bool> = true> -constexpr uint32_t BitAnd(Enum a, Enum b) +template ::value and std::is_same_v) or + (is_scoped_enum::value and std::is_same_v) or + (is_scoped_enum::value and std::is_same_v)>* = + nullptr> +constexpr uint32_t BitAnd(T1 a, T2 b) { return static_cast(a) & static_cast(b); } -template ::value, bool> = true> -constexpr uint32_t BitAnd(uint32_t a, Enum b) -{ - return a & static_cast(b); -} - /** * @brief Enums for the flags in the page header */ From 7fff8f05a162de617c085828ca8ff95338bac53e Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 29 Sep 2023 11:43:03 -0700 Subject: [PATCH 20/37] remove TODO --- cpp/src/io/parquet/parquet_gpu.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index fefc0db9d70..37a00402a3a 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -90,9 +90,7 @@ struct input_column_info { namespace gpu { -// TODO: The delta encodings use ULEB128 integers, but for now we're only -// using max 64 bits. Need to see what the performance impact is of using -// __int128_t rather than int64_t. +// The delta encodings use ULEB128 integers, but parquet only uses max 64 bits. 
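// [Annotation, not part of the patch: a worked example of the two encodings
//  named above. Zigzag maps signed to unsigned so small magnitudes stay
//  small: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ... (for 64-bit v this is
//  (v << 1) ^ (v >> 63) with an arithmetic shift). ULEB128 then writes 7 bits
//  per byte, least-significant group first, setting the high bit on every
//  byte but the last; e.g. 300 = 0b1'0010'1100 encodes as 0xAC 0x02. A
//  host-side decoder for sanity checking could look like:
//
//    uint64_t decode_uleb128(uint8_t const*& p)
//    {
//      uint64_t v = 0;
//      int shift  = 0;
//      while (*p & 0x80) {
//        v |= static_cast<uint64_t>(*p++ & 0x7f) << shift;
//        shift += 7;
//      }
//      return v | (static_cast<uint64_t>(*p++) << shift);
//    }
//  ]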
using uleb128_t = uint64_t; using zigzag128_t = int64_t; From 8a20836a0b82992d384bd907cd8c57cad79681b0 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 29 Sep 2023 12:07:09 -0700 Subject: [PATCH 21/37] add an ifdef around is_scoped_enum --- cpp/src/io/parquet/parquet_gpu.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 37a00402a3a..d412a5015e3 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -94,7 +94,8 @@ namespace gpu { using uleb128_t = uint64_t; using zigzag128_t = int64_t; -// TODO this is in C++23 +// this is in C++23 +#if !defined(__cpp_lib_is_scoped_enum) template > struct is_scoped_enum { static const bool value = not std::is_convertible_v>; @@ -104,6 +105,9 @@ template struct is_scoped_enum { static const bool value = false; }; +#else +using std::is_scoped_enum; +#endif // helpers to do bit operations on scoped enums template Date: Fri, 29 Sep 2023 12:19:12 -0700 Subject: [PATCH 22/37] avoid UB when calculating deltas --- cpp/src/io/parquet/delta_enc.cuh | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index 033376e345b..c88eab50ab9 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -156,6 +156,12 @@ class delta_binary_packer { _dst += 4; } + // signed subtraction with defined wrapping behavior + inline __device__ zigzag128_t subtract(zigzag128_t a, zigzag128_t b) + { + return static_cast(static_cast(a) - static_cast(b)); + } + public: inline __device__ auto num_values() const { return _num_values; } @@ -212,10 +218,10 @@ class delta_binary_packer { if (_values_in_buffer <= 0) { return _dst; } // calculate delta for this thread - size_type const idx = _current_idx + t; - zigzag128_t const delta = - idx < _num_values ? _buffer[delta::rolling_idx(idx)] - _buffer[delta::rolling_idx(idx - 1)] - : std::numeric_limits::max(); + size_type const idx = _current_idx + t; + zigzag128_t const delta = idx < _num_values ? subtract(_buffer[delta::rolling_idx(idx)], + _buffer[delta::rolling_idx(idx - 1)]) + : std::numeric_limits::max(); // find min delta for the block auto const min_delta = delta::block_reduce(*_block_tmp).Reduce(delta, cub::Min()); @@ -224,7 +230,7 @@ class delta_binary_packer { __syncthreads(); // compute frame of reference for the block - uleb128_t const norm_delta = idx < _num_values ? delta - block_min : 0; + uleb128_t const norm_delta = idx < _num_values ? 
subtract(delta, block_min) : 0; // get max normalized delta for each warp, and use that to determine how many bits to use // for the bitpacking of this warp From 068a01750ec8af9eaa61c2955bac80b396928681 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 29 Sep 2023 17:10:22 -0700 Subject: [PATCH 23/37] remove unnecessary sync --- cpp/src/io/parquet/delta_enc.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index c88eab50ab9..bdaf524e469 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -87,7 +87,6 @@ inline __device__ void bitpack_mini_block( val >>= 8; } } - __syncwarp(); return; } From 6be661e8fa7863b8c1399f7128c15c5251dfd8a3 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Sat, 30 Sep 2023 16:38:27 -0700 Subject: [PATCH 24/37] lost some constants somewhere --- cpp/src/io/parquet/page_enc.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index ece92bea44d..2a9df5e3fb2 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -1688,8 +1688,8 @@ __global__ void __launch_bounds__(block_size, 8) // save RLE length if necessary if (s->rle_len_pos != nullptr && t < 32) { // size doesn't include the 4 bytes for the length - auto const rle_size = static_cast(s->cur - s->rle_len_pos) - 4; - if (t < 4) { s->rle_len_pos[t] = rle_size >> (t * 8); } + auto const rle_size = static_cast(s->cur - s->rle_len_pos) - RLE_LENGTH_FIELD_LEN; + if (t < RLE_LENGTH_FIELD_LEN) { s->rle_len_pos[t] = rle_size >> (t * 8); } __syncwarp(); } From bd19be0870ccd4f6413345c3deeb83eb6d423f60 Mon Sep 17 00:00:00 2001 From: seidl Date: Mon, 2 Oct 2023 11:07:56 -0700 Subject: [PATCH 25/37] address some review comments --- cpp/src/io/parquet/page_enc.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 2a9df5e3fb2..7f8af2d5735 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -257,15 +257,15 @@ struct delta_enc { uint64_t* const buffer; void* const temp_space; - __device__ thrust::pair calc_idx_and_validity(uint32_t cur_val_idx) + __device__ thrust::pair calc_validity_and_idx(uint32_t cur_val_idx) { size_type const val_idx_in_block = cur_val_idx + threadIdx.x; size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; - uint32_t const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && - val_idx_in_block < s->page.num_leaf_values) - ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) - : 0; + bool const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && + val_idx_in_block < s->page.num_leaf_values) + ? 
From bd19be0870ccd4f6413345c3deeb83eb6d423f60 Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 11:07:56 -0700
Subject: [PATCH 25/37] address some review comments

---
 cpp/src/io/parquet/page_enc.cu | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu
index 2a9df5e3fb2..7f8af2d5735 100644
--- a/cpp/src/io/parquet/page_enc.cu
+++ b/cpp/src/io/parquet/page_enc.cu
@@ -257,15 +257,15 @@ struct delta_enc {
   uint64_t* const buffer;
   void* const temp_space;
 
-  __device__ thrust::pair<uint32_t, size_type> calc_idx_and_validity(uint32_t cur_val_idx)
+  __device__ thrust::pair<bool, size_type> calc_validity_and_idx(uint32_t cur_val_idx)
   {
     size_type const val_idx_in_block    = cur_val_idx + threadIdx.x;
     size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block;
 
-    uint32_t const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() &&
-                               val_idx_in_block < s->page.num_leaf_values)
-                                ? s->col.leaf_column->is_valid(val_idx_in_leaf_col)
-                                : 0;
+    bool const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() &&
+                           val_idx_in_block < s->page.num_leaf_values)
+                            ? s->col.leaf_column->is_valid(val_idx_in_leaf_col)
+                            : false;
 
     return {is_valid, val_idx_in_leaf_col};
   }
 
@@ -287,7 +287,7 @@ struct delta_enc {
     for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) {
       uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size);
 
-      auto [is_valid, val_idx] = calc_idx_and_validity(cur_val_idx);
+      auto [is_valid, val_idx] = calc_validity_and_idx(cur_val_idx);
       cur_val_idx += nvals;
 
       output_type v = s->col.leaf_column->element<I>(val_idx);

From 84f02a760764beb8433fec8d2a733bf691f7d178 Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 11:45:54 -0700
Subject: [PATCH 26/37] reduce register pressure a bit

---
 cpp/src/io/parquet/delta_enc.cuh | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh
index bdaf524e469..35e1c525213 100644
--- a/cpp/src/io/parquet/delta_enc.cuh
+++ b/cpp/src/io/parquet/delta_enc.cuh
@@ -245,12 +245,13 @@ class delta_binary_packer {
     __syncthreads();
 
     // now each warp encodes its data...can calculate starting offset with _mb_bits
-    uint8_t* mb_ptr = _dst;
+    int cumulative_bits = 0;
     switch (warp_id) {
-      case 3: mb_ptr += _mb_bits[2] * delta::values_per_mini_block / 8; [[fallthrough]];
-      case 2: mb_ptr += _mb_bits[1] * delta::values_per_mini_block / 8; [[fallthrough]];
-      case 1: mb_ptr += _mb_bits[0] * delta::values_per_mini_block / 8;
+      case 3: cumulative_bits += _mb_bits[2]; [[fallthrough]];
+      case 2: cumulative_bits += _mb_bits[1]; [[fallthrough]];
+      case 1: cumulative_bits += _mb_bits[0];
     }
+    uint8_t* const mb_ptr = _dst + cumulative_bits * delta::values_per_mini_block / 8;
 
     // encoding happens here
     auto const warp_idx = _current_idx + warp_id * delta::values_per_mini_block;
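The switch with [[fallthrough]] in patch 26 is computing a prefix sum of the earlier mini-blocks' bit widths to find where this warp's mini-block starts in the output. Written as a plain loop — which, per the commit, costs more registers on the device — the arithmetic looks like this (a standalone sketch with made-up widths):

    #include <cstdint>

    constexpr int values_per_mini_block = 32;

    // Byte offset of mini-block `warp_id`: the sum of all earlier mini-blocks'
    // bit widths, scaled to bytes. nbits * 32 is always a whole byte count.
    int mini_block_offset(uint8_t const mb_bits[4], int warp_id)
    {
      int cumulative_bits = 0;
      for (int i = 0; i < warp_id; i++) {
        cumulative_bits += mb_bits[i];
      }
      return cumulative_bits * values_per_mini_block / 8;
    }

    int main()
    {
      uint8_t const widths[4] = {3, 7, 1, 12};
      // mini-block 2 starts (3 + 7) * 32 / 8 = 40 bytes past the block header
      return mini_block_offset(widths, 2) == 40 ? 0 : 1;
    }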
From 806bca249b0f768985f7bd63f1fcfbd9bb8e2973 Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 12:21:14 -0700
Subject: [PATCH 27/37] add some comments and sanity checks

---
 cpp/src/io/parquet/delta_enc.cuh | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh
index 35e1c525213..e42ca43c2b4 100644
--- a/cpp/src/io/parquet/delta_enc.cuh
+++ b/cpp/src/io/parquet/delta_enc.cuh
@@ -44,11 +44,17 @@ inline __device__ void put_zz128(uint8_t*& p, zigzag128_t v)
 
 // a block size of 128, with 4 mini-blocks of 32 values each fits nicely without consuming
 // too much shared memory.
+// the parquet spec requires block_size to be a multiple of 128, and values_per_mini_block
+// to be a multiple of 32.
 constexpr int block_size            = 128;
 constexpr int num_mini_blocks       = 4;
 constexpr int values_per_mini_block = block_size / num_mini_blocks;
 constexpr int buffer_size           = 2 * block_size;
 
+// extra sanity checks to enforce compliance with the parquet specification
+static_assert(block_size % 128 == 0);
+static_assert(values_per_mini_block % 32 == 0);
+
 using block_reduce = cub::BlockReduce<zigzag128_t, block_size>;
 using warp_reduce  = cub::WarpReduce<uleb128_t>;
 using index_scan   = cub::BlockScan<size_type, block_size>;
@@ -245,6 +251,8 @@ class delta_binary_packer {
     __syncthreads();
 
     // now each warp encodes its data...can calculate starting offset with _mb_bits
+    // NOTE: using a switch here rather than a loop because the compiler produces code that
+    // uses fewer registers.
     int cumulative_bits = 0;
     switch (warp_id) {
       case 3: cumulative_bits += _mb_bits[2]; [[fallthrough]];

From ffe2d907905f27be67af4f0988d586d0315adcb3 Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 12:22:44 -0700
Subject: [PATCH 28/37] add test with sliced table

---
 cpp/tests/io/parquet_test.cpp | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp
index cbc306ec63d..32ef2375bca 100644
--- a/cpp/tests/io/parquet_test.cpp
+++ b/cpp/tests/io/parquet_test.cpp
@@ -6764,4 +6764,28 @@ TYPED_TEST(ParquetWriterDeltaTest, SupportedDeltaTestTypes)
   CUDF_TEST_EXPECT_TABLES_EQUAL(expected, result.tbl->view());
 }
 
+TYPED_TEST(ParquetWriterDeltaTest, SupportedDeltaTestTypesSliced)
+{
+  using T                = TypeParam;
+  constexpr int num_rows = 4'000;
+  auto col0              = testdata::ascending<T>();
+  auto col1              = testdata::unordered<T>();
+
+  auto const expected = table_view{{col0, col1}};
+  auto expected_slice = cudf::slice(expected, {num_rows, 2 * num_rows});
+  ASSERT_EQ(expected_slice[0].num_rows(), num_rows);
+
+  auto const filepath = temp_env->get_temp_filepath("DeltaBinaryPackedSliced.parquet");
+  cudf::io::parquet_writer_options out_opts =
+    cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected_slice)
+      .write_v2_headers(true)
+      .dictionary_policy(cudf::io::dictionary_policy::NEVER);
+  cudf::io::write_parquet(out_opts);
+
+  cudf::io::parquet_reader_options in_opts =
+    cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath});
+  auto result = cudf::io::read_parquet(in_opts);
+  CUDF_TEST_EXPECT_TABLES_EQUAL(expected_slice, result.tbl->view());
+}
+
 CUDF_TEST_PROGRAM_MAIN()

From 4f47bddfd24884402a8833303636ea80f5167548 Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 13:21:24 -0700
Subject: [PATCH 29/37] make comment match reality

---
 cpp/src/io/parquet/page_enc.cu | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu
index 7f8af2d5735..18b84c4a020 100644
--- a/cpp/src/io/parquet/page_enc.cu
+++ b/cpp/src/io/parquet/page_enc.cu
@@ -403,13 +403,13 @@ __device__ size_t delta_data_len(parquet::Type physical_type,
   auto const vals_per_block = delta::block_size;
   size_t const num_blocks   = util::div_rounding_up_unsafe(num_values, vals_per_block);
-  // need max dtype_len_in + 1 bytes for min_delta
+  // need max dtype_len + 1 bytes for min_delta
   // one byte per mini block for the bitwidth
-  // and block_size * dtype_len_in bytes for the actual encoded data
+  // and block_size * dtype_len bytes for the actual encoded data
   auto const block_size = dtype_len + 1 + delta::num_mini_blocks + vals_per_block * dtype_len;
 
   // delta header is 2 bytes for the block_size, 1 byte for number of mini-blocks,
-  // max 5 bytes for number of values, and max dtype_len_in + 1 for first value.
+  // max 5 bytes for number of values, and max dtype_len + 1 for first value.
   auto const header_size = 2 + 1 + 5 + dtype_len + 1;
 
   return header_size + num_blocks * block_size;
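Patch 29's corrected comments describe the worst-case sizing done by delta_data_len. As a sanity check, here is the same bound as a standalone constexpr function with one worked example — constants mirrored from the patches, but this is an illustration, not the cudf code itself:

    #include <cstddef>

    constexpr std::size_t delta_data_len_bound(std::size_t dtype_len, std::size_t num_values)
    {
      constexpr std::size_t block_size_vals = 128;  // delta::block_size
      constexpr std::size_t num_mini_blocks = 4;    // delta::num_mini_blocks
      std::size_t const num_blocks = (num_values + block_size_vals - 1) / block_size_vals;
      // per block: min_delta (dtype_len + 1), one bitwidth byte per mini-block,
      // and up to dtype_len bytes of packed data per value
      std::size_t const block_bytes = dtype_len + 1 + num_mini_blocks + block_size_vals * dtype_len;
      // header: 2 (block size) + 1 (mini-block count) + 5 (value count) + dtype_len + 1 (first value)
      std::size_t const header_bytes = 2 + 1 + 5 + dtype_len + 1;
      return header_bytes + num_blocks * block_bytes;
    }

    // e.g. 4'000 int64 values: 32 blocks of 1'037 bytes plus a 17-byte header
    static_assert(delta_data_len_bound(8, 4'000) == 17 + 32 * 1037);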
From c60a4a992d1cb3129e586c14aa7102f24dc4aedf Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 13:35:08 -0700
Subject: [PATCH 30/37] remove some template declarations

---
 cpp/src/io/parquet/page_enc.cu | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu
index 18b84c4a020..fcb8953e600 100644
--- a/cpp/src/io/parquet/page_enc.cu
+++ b/cpp/src/io/parquet/page_enc.cu
@@ -100,6 +100,8 @@ struct page_enc_state_s {
   uint32_t vals[rle_buf_size];
 };
 
+using rle_page_enc_state_s = page_enc_state_s<rle_buf_size>;
+
 /**
  * @brief Returns the size of the type in the Parquet file.
  */
@@ -905,9 +907,8 @@ inline __device__ void PackLiterals(
  * @param[in] flush nonzero if last batch in block
  * @param[in] t thread id (0..127)
  */
-template <typename state_buf>
 static __device__ void RleEncode(
-  state_buf* s, uint32_t numvals, uint32_t nbits, uint32_t flush, uint32_t t)
+  rle_page_enc_state_s* s, uint32_t numvals, uint32_t nbits, uint32_t flush, uint32_t t)
{
  using cudf::detail::warp_size;
  auto const lane_id = t % warp_size;
@@ -1050,8 +1051,10 @@ static __device__ void RleEncode(
  * @param[in] flush nonzero if last batch in block
  * @param[in] t thread id (0..127)
  */
-template <typename state_buf>
-static __device__ void PlainBoolEncode(state_buf* s, uint32_t numvals, uint32_t flush, uint32_t t)
+static __device__ void PlainBoolEncode(rle_page_enc_state_s* s,
+                                       uint32_t numvals,
+                                       uint32_t flush,
+                                       uint32_t t)
 {
   uint32_t rle_pos = s->rle_pos;
   uint8_t* dst     = s->rle_out;
@@ -1120,13 +1123,13 @@ template <int block_size>
 __global__ void __launch_bounds__(block_size, 8)
   gpuEncodePageLevels(device_span<EncPage> pages, bool write_v2_headers)
 {
-  __shared__ __align__(8) page_enc_state_s<rle_buf_size> state_g;
+  __shared__ __align__(8) rle_page_enc_state_s state_g;
 
   auto* const s    = &state_g;
   uint32_t const t = threadIdx.x;
 
   if (t == 0) {
-    state_g = page_enc_state_s<rle_buf_size>{};
+    state_g = rle_page_enc_state_s{};
     s->page = pages[blockIdx.x];
     s->ck   = *s->page.chunk;
     s->col  = *s->ck.col_desc;
@@ -1566,7 +1569,7 @@ __global__ void __launch_bounds__(block_size, 8)
     device_span<compression_result> comp_results,
     bool write_v2_headers)
 {
-  __shared__ __align__(8) page_enc_state_s<rle_buf_size> state_g;
+  __shared__ __align__(8) rle_page_enc_state_s state_g;
   using block_reduce = cub::BlockReduce<uint32_t, block_size>;
   using block_scan   = cub::BlockScan<uint32_t, block_size>;
   __shared__ union {
     typename block_reduce::TempStorage reduce_storage;
     typename block_scan::TempStorage scan_storage;
   } temp_storage;
 
   auto* const s = &state_g;
   uint32_t t    = threadIdx.x;
 
   if (t == 0) {
-    state_g = page_enc_state_s<rle_buf_size>{};
+    state_g = rle_page_enc_state_s{};
     s->page = pages[blockIdx.x];
     s->ck   = *s->page.chunk;
     s->col  = *s->ck.col_desc;

From 9049b7d287cd3a7efc1588bc155e3287af7a8184 Mon Sep 17 00:00:00 2001
From: Ed Seidl
Date: Mon, 2 Oct 2023 13:36:00 -0700
Subject: [PATCH 31/37] implement suggestion from review

Co-authored-by: Vukasin Milovanovic
---
 cpp/src/io/parquet/page_enc.cu | 1 -
 1 file changed, 1 deletion(-)

diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu
index 7f8af2d5735..8278752e6a8 100644
--- a/cpp/src/io/parquet/page_enc.cu
+++ b/cpp/src/io/parquet/page_enc.cu
@@ -1287,7 +1287,6 @@ __device__ void finish_page_encode(state_buf* s,
   uint8_t const* const base   = s->page.page_data + s->page.max_hdr_size;
   auto const actual_data_size = static_cast<uint32_t>(end_ptr - base);
   if (actual_data_size > s->page.max_data_size) {
-    printf("data corruption %d %d\n", actual_data_size, s->page.max_data_size);
     CUDF_UNREACHABLE("detected possible page data corruption");
   }
   s->page.max_data_size = actual_data_size;
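The next patch dispatches on encode_kernel_mask with BitAnd, one of the scoped-enum helpers that patch 21's is_scoped_enum guard supports. A minimal stand-in for that idiom follows — a sketch only; the real helpers live in parquet_gpu.hpp and the mask values here are assumed for illustration:

    #include <cstdint>
    #include <type_traits>

    enum class encode_kernel_mask : uint32_t {
      PLAIN        = (1 << 0),
      DICTIONARY   = (1 << 1),
      DELTA_BINARY = (1 << 2),
    };

    // Scoped enums don't convert to integers implicitly, so bit tests need an
    // explicit cast; wrapping the cast in a constrained helper keeps call sites clean.
    template <typename Enum, std::enable_if_t<std::is_enum_v<Enum>>* = nullptr>
    constexpr uint32_t BitAnd(Enum a, Enum b)
    {
      return static_cast<uint32_t>(a) & static_cast<uint32_t>(b);
    }

    static_assert(BitAnd(encode_kernel_mask::PLAIN, encode_kernel_mask::PLAIN) != 0);
    static_assert(BitAnd(encode_kernel_mask::PLAIN, encode_kernel_mask::DICTIONARY) == 0);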
From 2a9d5d91164d6d2c157b90e3c2808f4807f8d948 Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 13:41:28 -0700
Subject: [PATCH 32/37] remove another template param

---
 cpp/src/io/parquet/page_enc.cu | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu
index 01f94709fb3..d85d4f85cd2 100644
--- a/cpp/src/io/parquet/page_enc.cu
+++ b/cpp/src/io/parquet/page_enc.cu
@@ -1119,9 +1119,11 @@ __device__ auto julian_days_with_time(int64_t v)
 // for the state buffer. encode kernels that don't use the RLE buffer can get started while
 // the level data is encoded.
 // blockDim(128, 1, 1)
-template <int block_size, encode_kernel_mask kernel_mask>
+template <int block_size>
 __global__ void __launch_bounds__(block_size, 8)
-  gpuEncodePageLevels(device_span<EncPage> pages, bool write_v2_headers)
+  gpuEncodePageLevels(device_span<EncPage> pages,
+                      bool write_v2_headers,
+                      encode_kernel_mask kernel_mask)
@@ -2734,22 +2736,22 @@ void EncodePages(device_span<EncPage> pages,
   int s_idx = 0;
   if (BitAnd(kernel_mask, encode_kernel_mask::PLAIN) != 0) {
     auto const strm = streams[s_idx++];
-    gpuEncodePageLevels<block_size, encode_kernel_mask::PLAIN>
-      <<<num_pages, block_size, 0, strm.value()>>>(pages, write_v2_headers);
+    gpuEncodePageLevels<block_size><<<num_pages, block_size, 0, strm.value()>>>(
+      pages, write_v2_headers, encode_kernel_mask::PLAIN);
     gpuEncodePages<block_size><<<num_pages, block_size, 0, strm.value()>>>(
       pages, comp_in, comp_out, comp_results, write_v2_headers);
   }
   if (BitAnd(kernel_mask, encode_kernel_mask::DELTA_BINARY) != 0) {
     auto const strm = streams[s_idx++];
-    gpuEncodePageLevels<block_size, encode_kernel_mask::DELTA_BINARY>
-      <<<num_pages, block_size, 0, strm.value()>>>(pages, write_v2_headers);
+    gpuEncodePageLevels<block_size><<<num_pages, block_size, 0, strm.value()>>>(
+      pages, write_v2_headers, encode_kernel_mask::DELTA_BINARY);
     gpuEncodeDeltaBinaryPages<block_size>
       <<<num_pages, block_size, 0, strm.value()>>>(pages, comp_in, comp_out, comp_results);
   }
   if (BitAnd(kernel_mask, encode_kernel_mask::DICTIONARY) != 0) {
     auto const strm = streams[s_idx++];
-    gpuEncodePageLevels<block_size, encode_kernel_mask::DICTIONARY>
-      <<<num_pages, block_size, 0, strm.value()>>>(pages, write_v2_headers);
+    gpuEncodePageLevels<block_size><<<num_pages, block_size, 0, strm.value()>>>(
+      pages, write_v2_headers, encode_kernel_mask::DICTIONARY);
     gpuEncodeDictPages<block_size><<<num_pages, block_size, 0, strm.value()>>>(
       pages, comp_in, comp_out, comp_results, write_v2_headers);
   }

From 8604b533d34fe50be8e82ce50d10a7424d74076e Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 14:03:47 -0700
Subject: [PATCH 33/37] replace encoder struct with function

---
 cpp/src/io/parquet/page_enc.cu | 101 +++++++++++++--------------------
 1 file changed, 40 insertions(+), 61 deletions(-)

diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu
index d85d4f85cd2..d13c8ee8ff0 100644
--- a/cpp/src/io/parquet/page_enc.cu
+++ b/cpp/src/io/parquet/page_enc.cu
@@ -251,59 +251,47 @@ struct BitwiseOr {
 
 // I is the column type from the input table
 template <typename I>
-struct delta_enc {
+__device__ uint8_t const* delta_encode(page_enc_state_s<0>* s,
+                                       uint32_t valid_count,
+                                       uint64_t* buffer,
+                                       void* temp_space)
+{
   using output_type = std::conditional_t<std::is_signed_v<I>, zigzag128_t, uleb128_t>;
+  __shared__ delta_binary_packer<output_type> packer;
 
-  page_enc_state_s<0>* const s;
-  uint32_t const valid_count;
-  uint64_t* const buffer;
-  void* const temp_space;
-
-  __device__ thrust::pair<bool, size_type> calc_validity_and_idx(uint32_t cur_val_idx)
-  {
-    size_type const val_idx_in_block    = cur_val_idx + threadIdx.x;
-    size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block;
-
-    bool const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() &&
-                           val_idx_in_block < s->page.num_leaf_values)
-                            ? s->col.leaf_column->is_valid(val_idx_in_leaf_col)
-                            : false;
-
-    return {is_valid, val_idx_in_leaf_col};
+  auto const t = threadIdx.x;
+  if (t == 0) {
+    packer.init(s->cur, valid_count, reinterpret_cast<output_type*>(buffer), temp_space);
   }
+  __syncthreads();
 
-  __device__ uint8_t const* encode()
-  {
-    __shared__ delta_binary_packer<output_type> packer;
-
-    auto const t = threadIdx.x;
-
-    if (t == 0) {
-      packer.init(s->cur, valid_count, reinterpret_cast<output_type*>(buffer), temp_space);
-    }
-    __syncthreads();
-
-    // TODO(ets): in the plain encoder the scaling is a little different for INT32 than INT64.
-    // might need to modify this if there's a big performance hit in the 32-bit case.
-    int32_t const scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale;
-    for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) {
-      uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size);
+  // TODO(ets): in the plain encoder the scaling is a little different for INT32 than INT64.
+  // might need to modify this if there's a big performance hit in the 32-bit case.
+  int32_t const scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale;
+  for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) {
+    uint32_t const nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size);
+
+    size_type const val_idx_in_block = cur_val_idx + t;
+    size_type const val_idx          = s->page_start_val + val_idx_in_block;
+
+    bool const is_valid =
+      (val_idx < s->col.leaf_column->size() && val_idx_in_block < s->page.num_leaf_values)
+        ? s->col.leaf_column->is_valid(val_idx)
+        : false;
s->col.leaf_column->is_valid(val_idx_in_leaf_col) - : false; - - return {is_valid, val_idx_in_leaf_col}; + auto const t = threadIdx.x; + if (t == 0) { + packer.init(s->cur, valid_count, reinterpret_cast(buffer), temp_space); } + __syncthreads(); - __device__ uint8_t const* encode() - { - __shared__ delta_binary_packer packer; - - auto const t = threadIdx.x; + // TODO(ets): in the plain encoder the scaling is a little different for INT32 than INT64. + // might need to modify this if there's a big performance hit in the 32-bit case. + int32_t const scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale; + for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { + uint32_t const nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size); - if (t == 0) { - packer.init(s->cur, valid_count, reinterpret_cast(buffer), temp_space); - } - __syncthreads(); + size_type const val_idx_in_block = cur_val_idx + t; + size_type const val_idx = s->page_start_val + val_idx_in_block; - // TODO(ets): in the plain encoder the scaling is a little different for INT32 than INT64. - // might need to modify this if there's a big performance hit in the 32-bit case. - int32_t const scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale; - for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { - uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size); + bool const is_valid = + (val_idx < s->col.leaf_column->size() && val_idx_in_block < s->page.num_leaf_values) + ? s->col.leaf_column->is_valid(val_idx) + : false; - auto [is_valid, val_idx] = calc_validity_and_idx(cur_val_idx); - cur_val_idx += nvals; + cur_val_idx += nvals; - output_type v = s->col.leaf_column->element(val_idx); - if (scale < 0) { - v /= -scale; - } else { - v *= scale; - } - packer.add_value(v, is_valid); + output_type v = s->col.leaf_column->element(val_idx); + if (scale < 0) { + v /= -scale; + } else { + v *= scale; } - - return packer.flush(); + packer.add_value(v, is_valid); } -}; + + return packer.flush(); +} } // anonymous namespace @@ -1789,37 +1777,30 @@ __global__ void __launch_bounds__(block_size, 8) switch (dtype_len_in) { case 8: { // only DURATIONS map to 8 bytes, so safe to just use signed here? 
From 307017a4f9cec097e0f092f6e73c82ce443d4d21 Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 2 Oct 2023 14:45:57 -0700
Subject: [PATCH 34/37] add sliced list test

---
 cpp/tests/io/parquet_test.cpp | 44 +++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp
index 32ef2375bca..1ff133ddaa7 100644
--- a/cpp/tests/io/parquet_test.cpp
+++ b/cpp/tests/io/parquet_test.cpp
@@ -6788,4 +6788,48 @@ TYPED_TEST(ParquetWriterDeltaTest, SupportedDeltaTestTypesSliced)
   CUDF_TEST_EXPECT_TABLES_EQUAL(expected_slice, result.tbl->view());
 }
 
+TYPED_TEST(ParquetWriterDeltaTest, SupportedDeltaListSliced)
+{
+  using T = TypeParam;
+
+  constexpr int num_slice = 4'000;
+  constexpr int num_rows  = 32 * 1024;
+
+  std::mt19937 gen(6542);
+  std::bernoulli_distribution bn(0.7f);
+  auto valids =
+    cudf::detail::make_counting_transform_iterator(0, [&](int index) { return bn(gen); });
+  auto values = thrust::make_counting_iterator(0);
+
+  // list<T>
+  constexpr int vals_per_row = 4;
+  auto c1_offset_iter        = cudf::detail::make_counting_transform_iterator(
+    0, [vals_per_row](cudf::size_type idx) { return idx * vals_per_row; });
+  cudf::test::fixed_width_column_wrapper<cudf::size_type> c1_offsets(
+    c1_offset_iter, c1_offset_iter + num_rows + 1);
+  cudf::test::fixed_width_column_wrapper<T> c1_vals(
+    values, values + (num_rows * vals_per_row), valids);
+  auto [null_mask, null_count] = cudf::test::detail::make_null_mask(valids, valids + num_rows);
+
+  auto _c1 = cudf::make_lists_column(
+    num_rows, c1_offsets.release(), c1_vals.release(), null_count, std::move(null_mask));
+  auto c1 = cudf::purge_nonempty_nulls(*_c1);
+
+  auto const expected = table_view{{*c1}};
+  auto expected_slice = cudf::slice(expected, {num_slice, 2 * num_slice});
+  ASSERT_EQ(expected_slice[0].num_rows(), num_slice);
+
+  auto const filepath = temp_env->get_temp_filepath("DeltaBinaryPackedListSliced.parquet");
+  cudf::io::parquet_writer_options out_opts =
+    cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected_slice)
+      .write_v2_headers(true)
+      .dictionary_policy(cudf::io::dictionary_policy::NEVER);
+  cudf::io::write_parquet(out_opts);
+
+  cudf::io::parquet_reader_options in_opts =
+    cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath});
+  auto result = cudf::io::read_parquet(in_opts);
+  CUDF_TEST_EXPECT_TABLES_EQUAL(expected_slice, result.tbl->view());
+}
+
 CUDF_TEST_PROGRAM_MAIN()

From 3db131b6745f69f1d7bbd7eae5cd03b1d42464f5 Mon Sep 17 00:00:00 2001
From: seidl
Date: Mon, 9 Oct 2023 10:49:12 -0700
Subject: [PATCH 35/37] finish merge

---
 cpp/src/io/parquet/delta_enc.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh
index e42ca43c2b4..ba59d0d44d4 100644
--- a/cpp/src/io/parquet/delta_enc.cuh
+++ b/cpp/src/io/parquet/delta_enc.cuh
@@ -23,7 +23,7 @@
 
 #include <cub/cub.cuh>
 
-namespace cudf::io::parquet::gpu {
+namespace cudf::io::parquet::detail {
 
 namespace delta {
@@ -287,4 +287,4 @@ class delta_binary_packer {
   }
 };
 
-}  // namespace cudf::io::parquet::gpu
+}  // namespace cudf::io::parquet::detail

From 5db63128ff19a22b18c9cdab22a3ac82991f29e6 Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 18 Oct 2023 09:54:57 -0700
Subject: [PATCH 36/37] clean ups suggested in review

---
 cpp/src/io/parquet/delta_enc.cuh | 50 ++++++++++++++++----------------
 1 file changed, 25 insertions(+), 25 deletions(-)

diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh
index ba59d0d44d4..28f8cdfe2c1 100644
--- a/cpp/src/io/parquet/delta_enc.cuh
+++ b/cpp/src/io/parquet/delta_enc.cuh
@@ -30,10 +30,10 @@ namespace delta {
 inline __device__ void put_uleb128(uint8_t*& p, uleb128_t v)
 {
   while (v > 0x7f) {
-    *p++ = v | 0x80;
+    *(p++) = v | 0x80;
     v >>= 7;
   }
-  *p++ = v;
+  *(p++) = v;
 }
 
 inline __device__ void put_zz128(uint8_t*& p, zigzag128_t v)
@@ -42,16 +42,16 @@ inline __device__ void put_zz128(uint8_t*& p, zigzag128_t v)
   put_uleb128(p, (v ^ -s) * 2 + s);
 }
 
-// a block size of 128, with 4 mini-blocks of 32 values each fits nicely without consuming
+// A block size of 128, with 4 mini-blocks of 32 values each fits nicely without consuming
 // too much shared memory.
-// the parquet spec requires block_size to be a multiple of 128, and values_per_mini_block
+// The parquet spec requires block_size to be a multiple of 128, and values_per_mini_block
 // to be a multiple of 32.
 constexpr int block_size            = 128;
 constexpr int num_mini_blocks       = 4;
 constexpr int values_per_mini_block = block_size / num_mini_blocks;
 constexpr int buffer_size           = 2 * block_size;
 
-// extra sanity checks to enforce compliance with the parquet specification
+// Extra sanity checks to enforce compliance with the parquet specification.
 static_assert(block_size % 128 == 0);
 static_assert(values_per_mini_block % 32 == 0);
 
@@ -61,7 +61,7 @@ using index_scan = cub::BlockScan<size_type, block_size>;
 
 constexpr int rolling_idx(int index) { return rolling_index<buffer_size>(index); }
 
-// version of bit packer that can handle up to 64 bits values.
+// Version of bit packer that can handle up to 64 bits values.
 // T is the type to use for processing. if nbits <= 32 use uint32_t, otherwise unsigned long long
 // (not uint64_t because of atomicOr's typing). allowing this to be selectable since there's a
 // measurable impact to using the wider types.
@@ -97,7 +97,7 @@ inline __device__ void bitpack_mini_block(
   }
 
   if (lane_id <= count) {
-    // shift symbol left by up to mask bits
+    // Shift symbol left by up to mask bits.
     wide_type v2 = val;
     v2 <<= (lane_id * nbits) & mask;
 
@@ -105,13 +105,13 @@
     scratch_type v1[2];
     memcpy(&v1, &v2, sizeof(wide_type));
 
-    // Atomically write result to scratch
+    // Atomically write result to scratch.
     if (v1[0]) { atomicOr(scratch + ((lane_id * nbits) / div), v1[0]); }
     if (v1[1]) { atomicOr(scratch + ((lane_id * nbits) / div) + 1, v1[1]); }
   }
   __syncwarp();
 
-  // Copy scratch data to final destination
+  // Copy scratch data to final destination.
   auto const available_bytes = util::div_rounding_up_safe(count * nbits, 8U);
   auto const scratch_bytes   = reinterpret_cast<uint8_t const*>(scratch);
@@ -126,15 +126,15 @@
 // Object used to turn a stream of integers into a DELTA_BINARY_PACKED stream. This takes as input
 // 128 values with validity at a time, saving them until there are enough values for a block
 // to be written.
-// T is the input data type (either zigzag128_t or uleb128_t)
+// T is the input data type (either zigzag128_t or uleb128_t).
 template <typename T>
 class delta_binary_packer {
  private:
   uint8_t* _dst;                // sink to dump encoded values to
+  T* _buffer;                   // buffer to store values to be encoded
   size_type _current_idx;       // index of first value in buffer
   uint32_t _num_values;         // total number of values to encode
   size_type _values_in_buffer;  // current number of values stored in _buffer
-  T* _buffer;                   // buffer to store values to be encoded
   uint8_t _mb_bits[delta::num_mini_blocks];  // bitwidth for each mini-block
 
   // pointers to shared scratch memory for the warp and block scans/reduces
   delta::index_scan::TempStorage* _scan_tmp;
   delta::warp_reduce::TempStorage* _warp_tmp;
   delta::block_reduce::TempStorage* _block_tmp;
 
   void* _bitpack_tmp;  // pointer to shared scratch memory used in bitpacking
 
-  // write the delta binary header. only call from thread 0
+  // Write the delta binary header. Only call from thread 0.
   inline __device__ void write_header()
   {
     delta::put_uleb128(_dst, delta::block_size);
@@ -153,7 +153,7 @@
     delta::put_zz128(_dst, _buffer[0]);
   }
 
-  // write the block header. only call from thread 0
+  // Write the block header. Only call from thread 0.
   inline __device__ void write_block_header(zigzag128_t block_min)
   {
     delta::put_zz128(_dst, block_min);
@@ -161,7 +161,7 @@
     _dst += 4;
   }
 
-  // signed subtraction with defined wrapping behavior
+  // Signed subtraction with defined wrapping behavior.
   inline __device__ zigzag128_t subtract(zigzag128_t a, zigzag128_t b)
   {
     return static_cast<zigzag128_t>(static_cast<uleb128_t>(a) - static_cast<uleb128_t>(b));
   }
@@ -170,7 +170,7 @@
  public:
   inline __device__ auto num_values() const { return _num_values; }
 
-  // initialize the object. only call from thread 0
+  // Initialize the object. Only call from thread 0.
   inline __device__ void init(uint8_t* dest, uint32_t num_values, T* buffer, void* temp_storage)
   {
     _dst = dest;
@@ -184,10 +184,10 @@
     _values_in_buffer = 0;
   }
 
-  // each thread calls this to add its current value
+  // Each thread calls this to add its current value.
   inline __device__ void add_value(T value, bool is_valid)
   {
-    // figure out the correct position for the given value
+    // Figure out the correct position for the given value.
     size_type const valid = is_valid;
     size_type pos;
     size_type num_valid;
@@ -210,7 +210,7 @@
     if (_values_in_buffer >= delta::block_size) { flush(); }
   }
 
-  // called by each thread to flush data to the sink.
+  // Called by each thread to flush data to the sink.
   inline __device__ uint8_t const* flush()
   {
     using cudf::detail::warp_size;
@@ -222,23 +222,23 @@
     if (_values_in_buffer <= 0) { return _dst; }
 
-    // calculate delta for this thread
+    // Calculate delta for this thread.
     size_type const idx     = _current_idx + t;
     zigzag128_t const delta = idx < _num_values ? subtract(_buffer[delta::rolling_idx(idx)],
                                                            _buffer[delta::rolling_idx(idx - 1)])
                                                 : std::numeric_limits<zigzag128_t>::max();
 
-    // find min delta for the block
+    // Find min delta for the block.
     auto const min_delta = delta::block_reduce(*_block_tmp).Reduce(delta, cub::Min());
 
     if (t == 0) { block_min = min_delta; }
     __syncthreads();
 
-    // compute frame of reference for the block
+    // Compute frame of reference for the block.
     uleb128_t const norm_delta = idx < _num_values ? subtract(delta, block_min) : 0;
 
-    // get max normalized delta for each warp, and use that to determine how many bits to use
-    // for the bitpacking of this warp
+    // Get max normalized delta for each warp, and use that to determine how many bits to use
+    // for the bitpacking of this warp.
     zigzag128_t const warp_max =
       delta::warp_reduce(_warp_tmp[warp_id]).Reduce(norm_delta, cub::Max());
     __syncwarp();
@@ -250,7 +250,7 @@
     if (t == 0) { write_block_header(block_min); }
     __syncthreads();
 
-    // now each warp encodes its data...can calculate starting offset with _mb_bits
+    // Now each warp encodes its data...can calculate starting offset with _mb_bits.
     // NOTE: using a switch here rather than a loop because the compiler produces code that
     // uses fewer registers.
     int cumulative_bits = 0;
@@ -275,7 +275,7 @@
     }
     __syncthreads();
 
-    // last warp updates global delta ptr
+    // Last warp updates global delta ptr.
     if (warp_id == delta::num_mini_blocks - 1 && lane_id == 0) {
       _dst         = mb_ptr + _mb_bits[warp_id] * delta::values_per_mini_block / 8;
       _current_idx = min(warp_idx + delta::values_per_mini_block, _num_values);
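Since patch 36 touches put_uleb128 and put_zz128, a host-side mirror of the pair may help readers verify streams by hand. This is a sketch for illustration; the device versions above are the source of truth:

    #include <cstdint>

    using uleb128_t   = uint64_t;
    using zigzag128_t = int64_t;

    void put_uleb128(uint8_t*& p, uleb128_t v)
    {
      while (v > 0x7f) {
        *(p++) = v | 0x80;  // low 7 bits with the continuation bit set
        v >>= 7;
      }
      *(p++) = v;
    }

    void put_zz128(uint8_t*& p, zigzag128_t v)
    {
      zigzag128_t const s = (v < 0);
      // zigzag: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ... keeps small magnitudes small
      put_uleb128(p, (v ^ -s) * 2 + s);
    }

    int main()
    {
      uint8_t buf[16];
      uint8_t* p = buf;
      put_zz128(p, -2);  // zigzag(-2) = 3, which fits in one ULEB128 byte: 0x03
      return (p == buf + 1 && buf[0] == 0x03) ? 0 : 1;
    }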
From c1445edd4ae753263f99ba2aeb1c4a7ec073315c Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 18 Oct 2023 17:09:47 -0700
Subject: [PATCH 37/37] get rid of TODO

---
 cpp/src/io/parquet/page_enc.cu | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu
index e0367554936..1e4f061d2e0 100644
--- a/cpp/src/io/parquet/page_enc.cu
+++ b/cpp/src/io/parquet/page_enc.cu
@@ -215,6 +215,12 @@ void __device__ calculate_frag_size(frag_init_state_s* const s, int t)
   }
 }
 
+/**
+ * @brief Determine the correct page encoding for the given page parameters.
+ *
+ * This is only used by the plain and dictionary encoders. Delta encoders will set the page
+ * encoding directly.
+ */
 Encoding __device__ determine_encoding(PageType page_type,
                                        Type physical_type,
                                        bool use_dictionary,
@@ -226,7 +232,6 @@ Encoding __device__ determine_encoding(PageType page_type,
   switch (page_type) {
     case PageType::DATA_PAGE: return use_dictionary ? Encoding::PLAIN_DICTIONARY : Encoding::PLAIN;
     case PageType::DATA_PAGE_V2:
-      // TODO need to work in delta encodings here when they're added
       return physical_type == BOOLEAN ? Encoding::RLE
              : use_dictionary         ? Encoding::RLE_DICTIONARY
                                       : Encoding::PLAIN;
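Finally, stepping back from the last patch to the packer at the center of the series: bitpack_mini_block ORs shifted values into zeroed scratch words from 32 threads at once. A serial, host-side model of the same LSB-first packing follows; names and structure are mine, not the kernel's:

    #include <cstdint>
    #include <cstring>

    // Pack `count` values of `nbits` bits each, LSB-first, as Parquet expects.
    // Value i occupies absolute output bits [i * nbits, (i + 1) * nbits); the GPU
    // version produces the same layout by atomicOr-ing shifted values in parallel.
    void pack_bits(uint8_t* dst, uint64_t const* vals, int count, int nbits)
    {
      std::memset(dst, 0, (count * nbits + 7) / 8);
      for (int i = 0; i < count; i++) {
        for (int b = 0; b < nbits; b++) {
          int const bit = i * nbits + b;  // absolute output bit position
          if ((vals[i] >> b) & 1) { dst[bit / 8] |= uint8_t(1) << (bit % 8); }
        }
      }
    }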