From 1581773984acf3548ee90d0a014b5dbd3d517205 Mon Sep 17 00:00:00 2001
From: nvdbaranec <56695930+nvdbaranec@users.noreply.github.com>
Date: Mon, 15 May 2023 13:21:00 -0500
Subject: [PATCH] Optimization to decoding of parquet level streams (#13203)

An optimization to the decoding of the definition and repetition level streams in Parquet files. Previously, we decoded these streams using a single warp. With this optimization we decode arbitrarily wide (currently set to 512 threads), which gives a dramatic improvement.

The core of the work is in the new file `rle_stream.cuh`, which encapsulates the decoding into an `rle_stream` object (a brief usage sketch follows the patch body below). This PR only applies the optimization to the `gpuComputePageSizes` kernel, used for preprocessing list columns and for the chunked read case involving strings or lists. In addition, the `gpuUpdatePageSizes` function has been improved to work at the block level instead of using a single warp.

Testing with the cudf parquet reader list benchmarks shows as much as a **75%** reduction in time spent in the `gpuComputePageSizes` kernel. Future PRs will apply this to the `gpuDecodePageData` kernel.

Leaving as a draft for the moment - more detailed benchmarks and numbers forthcoming, along with some possible parameter tuning.

Benchmark info: a before/after sample from the `parquet_reader_io_compression` suite on an A5000. The kernel goes from 427 milliseconds to 93 milliseconds. This seems to be a fairly typical situation, although results will definitely be affected by the encoded data (run lengths, etc.).

![pq_opt1](https://user-images.githubusercontent.com/56695930/236043918-bcb01c00-d842-46f5-95bd-9579392cda5f.png)

The reader benchmarks that involve this kernel yield some great improvements.

```
parquet_read_decode (A = Before. B = After)

| data_type | io            | cardinality | run_length | bytes_per_second (A) | bytes_per_second (B) |
|-----------|---------------|-------------|------------|----------------------|----------------------|
| LIST      | DEVICE_BUFFER | 0           | 1          | 5399068099           | 6044036091           |
| LIST      | DEVICE_BUFFER | 1000        | 1          | 5930855807           | 6505889742           |
| LIST      | DEVICE_BUFFER | 0           | 32         | 6862874160           | 7531918407           |
| LIST      | DEVICE_BUFFER | 1000        | 32         | 6781795229           | 7463856554           |
```

```
parquet_read_io_compression (A = Before. B = After)

| io            | compression | bytes_per_second (A) | bytes_per_second (B) |
|---------------|-------------|----------------------|----------------------|
| DEVICE_BUFFER | SNAPPY      | 307421363            | 393735255            |
| DEVICE_BUFFER | SNAPPY      | 323998549            | 426045725            |
| DEVICE_BUFFER | SNAPPY      | 386112997            | 508751604            |
| DEVICE_BUFFER | SNAPPY      | 381398279            | 498963635            |
```

Authors:
  - https://github.com/nvdbaranec

Approvers:
  - Yunsong Wang (https://github.com/PointKernel)
  - Vukasin Milovanovic (https://github.com/vuule)

URL: https://github.com/rapidsai/cudf/pull/13203
---
 .../cudf/detail/utilities/integer_utils.hpp  |   6 +-
 cpp/src/io/parquet/page_data.cu              | 451 +++++++++++-------
 cpp/src/io/parquet/page_hdr.cu               |  10 +-
 cpp/src/io/parquet/parquet_gpu.hpp           |  13 +
 cpp/src/io/parquet/reader_impl.cpp           |   2 +-
 cpp/src/io/parquet/reader_impl.hpp           |   8 +
 cpp/src/io/parquet/reader_impl_preprocess.cu |  79 ++-
 cpp/src/io/parquet/rle_stream.cuh            | 359 ++++++++++++++
 8 files changed, 725 insertions(+), 203 deletions(-)
 create mode 100644 cpp/src/io/parquet/rle_stream.cuh

diff --git a/cpp/include/cudf/detail/utilities/integer_utils.hpp b/cpp/include/cudf/detail/utilities/integer_utils.hpp index 40faae7e9f4..ccc89b2dce3 100644 --- a/cpp/include/cudf/detail/utilities/integer_utils.hpp +++ b/cpp/include/cudf/detail/utilities/integer_utils.hpp @@ -1,7 +1,7 @@ /* * Copyright 2019 BlazingDB, Inc. * Copyright 2019 Eyal Rozenberg - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ namespace util { * `modulus` is positive. The safety is in regard to rollover. */ template -S round_up_safe(S number_to_round, S modulus) +constexpr S round_up_safe(S number_to_round, S modulus) { auto remainder = number_to_round % modulus; if (remainder == 0) { return number_to_round; } @@ -67,7 +67,7 @@ S round_up_safe(S number_to_round, S modulus) * `modulus` is positive and does not check for overflow.
*/ template -S round_down_safe(S number_to_round, S modulus) noexcept +constexpr S round_down_safe(S number_to_round, S modulus) noexcept { auto remainder = number_to_round % modulus; auto rounded_down = number_to_round - remainder; diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index 8cb01d5a34b..c00595b7dd6 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -15,6 +15,7 @@ */ #include "parquet_gpu.hpp" +#include "rle_stream.cuh" #include #include @@ -46,10 +47,15 @@ namespace gpu { namespace { -constexpr int block_size = 128; -constexpr int non_zero_buffer_size = block_size * 2; - +constexpr int preprocess_block_size = num_rle_stream_decode_threads; // 512 +constexpr int decode_block_size = 128; +constexpr int non_zero_buffer_size = decode_block_size * 2; constexpr int rolling_index(int index) { return index & (non_zero_buffer_size - 1); } +template +constexpr int rolling_lvl_index(int index) +{ + return index % lvl_buf_size; +} struct page_state_s { const uint8_t* data_start; @@ -82,11 +88,11 @@ struct page_state_s { int32_t input_value_count; // how many values of the input we've processed int32_t input_row_count; // how many rows of the input we've processed int32_t input_leaf_count; // how many leaf values of the input we've processed - uint32_t rep[non_zero_buffer_size]; // circular buffer of repetition level values - uint32_t def[non_zero_buffer_size]; // circular buffer of definition level values const uint8_t* lvl_start[NUM_LEVEL_TYPES]; // [def,rep] - int32_t lvl_count[NUM_LEVEL_TYPES]; // how many of each of the streams we've decoded - int32_t row_index_lower_bound; // lower bound of row indices we should process + const uint8_t* abs_lvl_start[NUM_LEVEL_TYPES]; // [def,rep] + const uint8_t* abs_lvl_end[NUM_LEVEL_TYPES]; // [def,rep] + int32_t lvl_count[NUM_LEVEL_TYPES]; // how many of each of the streams we've decoded + int32_t row_index_lower_bound; // lower bound of row indices we should process // a shared-memory cache of frequently used data when decoding. The source of this data is // normally stored in global memory which can yield poor performance. So, when possible @@ -144,32 +150,6 @@ inline __device__ bool is_page_contained(page_state_s* const s, size_t start_row return page_begin >= begin && page_end <= end; } -/** - * @brief Read a 32-bit varint integer - * - * @param[in,out] cur The current data position, updated after the read - * @param[in] end The end data position - * - * @return The 32-bit value read - */ -inline __device__ uint32_t get_vlq32(const uint8_t*& cur, const uint8_t* end) -{ - uint32_t v = *cur++; - if (v >= 0x80 && cur < end) { - v = (v & 0x7f) | ((*cur++) << 7); - if (v >= (0x80 << 7) && cur < end) { - v = (v & ((0x7f << 7) | 0x7f)) | ((*cur++) << 14); - if (v >= (0x80 << 14) && cur < end) { - v = (v & ((0x7f << 14) | (0x7f << 7) | 0x7f)) | ((*cur++) << 21); - if (v >= (0x80 << 21) && cur < end) { - v = (v & ((0x7f << 21) | (0x7f << 14) | (0x7f << 7) | 0x7f)) | ((*cur++) << 28); - } - } - } - } - return v; -} - /** * @brief Parse the beginning of the level section (definition or repetition), * initializes the initial RLE run & value, and returns the section length @@ -178,24 +158,31 @@ inline __device__ uint32_t get_vlq32(const uint8_t*& cur, const uint8_t* end) * @param[in] cur The current data position * @param[in] end The end of the data * @param[in] level_bits The bits required + * @param[in] is_decode_step True if we are performing the decode step. 
+ * @param[in,out] decoders The repetition and definition level stream decoders * * @return The length of the section */ +template __device__ uint32_t InitLevelSection(page_state_s* s, const uint8_t* cur, const uint8_t* end, - level_type lvl) + level_type lvl, + bool is_decode_step, + rle_stream* decoders) { int32_t len; int level_bits = s->col.level_bits[lvl]; Encoding encoding = lvl == level_type::DEFINITION ? s->page.definition_level_encoding : s->page.repetition_level_encoding; + auto start = cur; if (level_bits == 0) { len = 0; s->initial_rle_run[lvl] = s->page.num_input_values * 2; // repeated value s->initial_rle_value[lvl] = 0; s->lvl_start[lvl] = cur; + s->abs_lvl_start[lvl] = cur; } else if (encoding == Encoding::RLE) { // V2 only uses RLE encoding, so only perform check here if (s->page.def_lvl_bytes || s->page.rep_lvl_bytes) { @@ -207,6 +194,7 @@ __device__ uint32_t InitLevelSection(page_state_s* s, len = 0; s->error = 2; } + s->abs_lvl_start[lvl] = cur; if (!s->error) { uint32_t run = get_vlq32(cur, end); s->initial_rle_run[lvl] = run; @@ -220,17 +208,22 @@ __device__ uint32_t InitLevelSection(page_state_s* s, s->initial_rle_value[lvl] = v; } s->lvl_start[lvl] = cur; - if (cur > end) { s->error = 2; } } + + if (cur > end) { s->error = 2; } } else if (encoding == Encoding::BIT_PACKED) { len = (s->page.num_input_values * level_bits + 7) >> 3; s->initial_rle_run[lvl] = ((s->page.num_input_values + 7) >> 3) * 2 + 1; // literal run s->initial_rle_value[lvl] = 0; s->lvl_start[lvl] = cur; + s->abs_lvl_start[lvl] = cur; } else { s->error = 3; len = 0; } + + s->abs_lvl_end[lvl] = start + len; + return static_cast(len); } @@ -242,8 +235,9 @@ __device__ uint32_t InitLevelSection(page_state_s* s, * @param[in] t Warp0 thread ID (0..31) * @param[in] lvl The level type we are decoding - DEFINITION or REPETITION */ +template __device__ void gpuDecodeStream( - uint32_t* output, page_state_s* s, int32_t target_count, int t, level_type lvl) + level_t* output, page_state_s* s, int32_t target_count, int t, level_type lvl) { const uint8_t* cur_def = s->lvl_start[lvl]; const uint8_t* end = s->lvl_end; @@ -980,15 +974,18 @@ static __device__ void gpuOutputGeneric( * @param[in] chunks The global list of chunks * @param[in] min_row Crop all rows below min_row * @param[in] num_rows Maximum number of rows to read - * @param[in] is_decode_step If we are setting up for the decode step (instead of the preprocess - * step) + * @param[in] is_decode_step If we are setting up for the decode step (instead of the preprocess) + * @param[in] decoders rle_stream decoders which will be used for decoding levels. Optional. + * Currently only used by gpuComputePageSizes step) */ +template static __device__ bool setupLocalPageInfo(page_state_s* const s, PageInfo const* p, device_span chunks, size_t min_row, size_t num_rows, - bool is_decode_step) + bool is_decode_step, + rle_stream* decoders = nullptr) { int t = threadIdx.x; int chunk_idx; @@ -1005,7 +1002,7 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, chunk_idx = s->page.chunk_idx; if (!t) { s->col = chunks[chunk_idx]; } - // if we can use the decode cache, set it up now + // if we can use the nesting decode cache, set it up now auto const can_use_decode_cache = s->page.nesting_info_size <= max_cacheable_nesting_decode_info; if (can_use_decode_cache) { int depth = 0; @@ -1028,6 +1025,7 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, if (!t) { s->nesting_info = can_use_decode_cache ? 
s->nesting_decode_cache : s->page.nesting_decode; } + __syncthreads(); // zero counts @@ -1202,9 +1200,9 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s, s->first_output_value = 0; // Find the compressed size of repetition levels - cur += InitLevelSection(s, cur, end, level_type::REPETITION); + cur += InitLevelSection(s, cur, end, level_type::REPETITION, is_decode_step, decoders); // Find the compressed size of definition levels - cur += InitLevelSection(s, cur, end, level_type::DEFINITION); + cur += InitLevelSection(s, cur, end, level_type::DEFINITION, is_decode_step, decoders); s->dict_bits = 0; s->dict_base = nullptr; @@ -1370,14 +1368,19 @@ static __device__ void store_validity(PageNestingDecodeInfo* nesting_info, * @param[out] d The definition level up to which added values are not-null. if t is out of bounds, * d will be -1 * @param[in] s Local page information + * @param[in] rep Repetition level buffer + * @param[in] def Definition level buffer * @param[in] input_value_count The current count of input level values we have processed * @param[in] target_input_value_count The desired # of input level values we want to process * @param[in] t Thread index */ +template inline __device__ void get_nesting_bounds(int& start_depth, int& end_depth, int& d, page_state_s* s, + level_t const* const rep, + level_t const* const def, int input_value_count, int32_t target_input_value_count, int t) @@ -1386,14 +1389,14 @@ inline __device__ void get_nesting_bounds(int& start_depth, end_depth = -1; d = -1; if (input_value_count + t < target_input_value_count) { - int index = rolling_index(input_value_count + t); - d = s->def[index]; + int const index = rolling_lvl_index(input_value_count + t); + d = static_cast(def[index]); // if we have repetition (there are list columns involved) we have to // bound what nesting levels we apply values to if (s->col.max_level[level_type::REPETITION] > 0) { - int r = s->rep[index]; - start_depth = s->nesting_info[r].start_depth; - end_depth = s->nesting_info[d].end_depth; + level_t const r = rep[index]; + start_depth = s->nesting_info[r].start_depth; + end_depth = s->nesting_info[d].end_depth; } // for columns without repetition (even ones involving structs) we always // traverse the entire hierarchy. 
@@ -1411,11 +1414,16 @@ inline __device__ void get_nesting_bounds(int& start_depth, * @param[in] target_input_value_count The # of repetition/definition levels to process up to * @param[in] s Local page information * @param[out] sb Page state buffer output + * @param[in] rep Repetition level buffer + * @param[in] def Definition level buffer * @param[in] t Thread index */ +template static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count, page_state_s* s, page_state_buffers_s* sb, + level_t const* const rep, + level_t const* const def, int t) { // max nesting depth of the column @@ -1433,8 +1441,8 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu // determine the nesting bounds for this thread (the range of nesting depths we // will generate new value indices and validity bits for) int start_depth, end_depth, d; - get_nesting_bounds( - start_depth, end_depth, d, s, input_value_count, target_input_value_count, t); + get_nesting_bounds( + start_depth, end_depth, d, s, rep, def, input_value_count, target_input_value_count, t); // 4 interesting things to track: // thread_value_count : # of output values from the view of this thread @@ -1585,11 +1593,16 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_inpu * @param[in] s The local page state * @param[out] sb Page state buffer output * @param[in] target_leaf_count Target count of non-null leaf values to generate indices for + * @param[in] rep Repetition level buffer + * @param[in] def Definition level buffer * @param[in] t Thread index */ +template __device__ void gpuDecodeLevels(page_state_s* s, page_state_buffers_s* sb, int32_t target_leaf_count, + level_t* const rep, + level_t* const def, int t) { bool has_repetition = s->col.max_level[level_type::REPETITION] > 0; @@ -1598,8 +1611,8 @@ __device__ void gpuDecodeLevels(page_state_s* s, int cur_leaf_count = target_leaf_count; while (!s->error && s->nz_count < target_leaf_count && s->input_value_count < s->num_input_values) { - if (has_repetition) { gpuDecodeStream(s->rep, s, cur_leaf_count, t, level_type::REPETITION); } - gpuDecodeStream(s->def, s, cur_leaf_count, t, level_type::DEFINITION); + if (has_repetition) { gpuDecodeStream(rep, s, cur_leaf_count, t, level_type::REPETITION); } + gpuDecodeStream(def, s, cur_leaf_count, t, level_type::DEFINITION); __syncwarp(); // because the rep and def streams are encoded separately, we cannot request an exact @@ -1610,123 +1623,163 @@ __device__ void gpuDecodeLevels(page_state_s* s, : s->lvl_count[level_type::DEFINITION]; // process what we got back - gpuUpdateValidityOffsetsAndRowIndices(actual_leaf_count, s, sb, t); + gpuUpdateValidityOffsetsAndRowIndices( + actual_leaf_count, s, sb, rep, def, t); cur_leaf_count = actual_leaf_count + batch_size; __syncwarp(); } } /** - * @brief Process a batch of incoming repetition/definition level values to generate - * per-nesting level output column size for this page. + * @brief Returns the total size in bytes of string char data in the page. + * + * This function expects the dictionary position to be at 0 and will traverse + * the entire thing. * - * Each page represents one piece of the overall output column. The total output (cudf) - * column sizes are the sum of the values in each individual page. + * Operates on a single warp only. 
Expects t < 32 * - * @param[in] s The local page info - * @param[in] target_input_value_count The # of repetition/definition levels to process up to - * @param[in] t Thread index - * @param[in] bounds_set Whether or not s->row_index_lower_bound, s->first_row and s->num_rows - * have been computed for this page (they will only be set in the second/trim pass). + * @param s The local page info + * @param t Thread index + */ +__device__ size_type gpuDecodeTotalPageStringSize(page_state_s* s, int t) +{ + size_type target_pos = s->num_input_values; + size_type str_len = 0; + if (s->dict_base) { + auto const [new_target_pos, len] = gpuDecodeDictionaryIndices(s, nullptr, target_pos, t); + target_pos = new_target_pos; + str_len = len; + } else if ((s->col.data_type & 7) == BYTE_ARRAY) { + str_len = gpuInitStringDescriptors(s, nullptr, target_pos, t); + } + if (!t) { *(volatile int32_t*)&s->dict_pos = target_pos; } + return str_len; +} + +/** + * @brief Update output column sizes for every nesting level based on a batch + * of incoming decoded definition and repetition level values. + * + * If bounds_set is true, computes skipped_values and skipped_leaf_values for the + * page to indicate where we need to skip to based on min/max row. + * + * Operates at the block level. + * + * @param s The local page info + * @param target_value_count The target value count to process up to + * @param rep Repetition level buffer + * @param def Definition level buffer + * @param t Thread index + * @param bounds_set A boolean indicating whether or not min/max row bounds have been set */ +template static __device__ void gpuUpdatePageSizes(page_state_s* s, - int32_t target_input_value_count, + int target_value_count, + level_t const* const rep, + level_t const* const def, int t, bool bounds_set) { // max nesting depth of the column int const max_depth = s->col.max_nesting_depth; + + constexpr int num_warps = preprocess_block_size / 32; + constexpr int max_batch_size = num_warps * 32; + + using block_reduce = cub::BlockReduce; + using block_scan = cub::BlockScan; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename block_scan::TempStorage scan_storage; + } temp_storage; + // how many input level values we've processed in the page so far - int input_value_count = s->input_value_count; - // how many leaf values we've processed in the page so far - int input_leaf_count = s->input_leaf_count; + int value_count = s->input_value_count; // how many rows we've processed in the page so far - int input_row_count = s->input_row_count; + int row_count = s->input_row_count; + // how many leaf values we've processed in the page so far + int leaf_count = s->input_leaf_count; + // whether or not we need to continue checking for the first row + bool skipped_values_set = s->page.skipped_values >= 0; - while (input_value_count < target_input_value_count) { - int start_depth, end_depth, d; - get_nesting_bounds( - start_depth, end_depth, d, s, input_value_count, target_input_value_count, t); + while (value_count < target_value_count) { + int const batch_size = min(max_batch_size, target_value_count - value_count); - // count rows and leaf values - int const is_new_row = start_depth == 0 ? 1 : 0; - uint32_t const warp_row_count_mask = ballot(is_new_row); - int const is_new_leaf = (d >= s->nesting_info[max_depth - 1].max_def_level) ? 
1 : 0; - uint32_t const warp_leaf_count_mask = ballot(is_new_leaf); + // start/end depth + int start_depth, end_depth, d; + get_nesting_bounds( + start_depth, end_depth, d, s, rep, def, value_count, value_count + batch_size, t); - // is this thread within row bounds? on the first pass we don't know the bounds, so we will be - // computing the full size of the column. on the second pass, we will know our actual row - // bounds, so the computation will cap sizes properly. + // is this thread within row bounds? in the non skip_rows/num_rows case this will always + // be true. int in_row_bounds = 1; + + // if we are in the skip_rows/num_rows case, we need to check against these limits if (bounds_set) { - // absolute row index - int32_t thread_row_index = - input_row_count + ((__popc(warp_row_count_mask & ((1 << t) - 1)) + is_new_row) - 1); - in_row_bounds = thread_row_index >= s->row_index_lower_bound && - thread_row_index < (s->first_row + s->num_rows) - ? 1 - : 0; - - uint32_t const row_bounds_mask = ballot(in_row_bounds); - int const first_thread_in_range = __ffs(row_bounds_mask) - 1; - - // if we've found the beginning of the first row, mark down the position - // in the def/repetition buffer (skipped_values) and the data buffer (skipped_leaf_values) - if (!t && first_thread_in_range >= 0 && s->page.skipped_values < 0) { - // how many values we've skipped in the rep/def levels - s->page.skipped_values = input_value_count + first_thread_in_range; - // how many values we've skipped in the actual data stream - s->page.skipped_leaf_values = - input_leaf_count + __popc(warp_leaf_count_mask & ((1 << first_thread_in_range) - 1)); + // get absolute thread row index + int const is_new_row = start_depth == 0; + int thread_row_count, block_row_count; + block_scan(temp_storage.scan_storage) + .InclusiveSum(is_new_row, thread_row_count, block_row_count); + __syncthreads(); + + // get absolute thread leaf index + int const is_new_leaf = (d >= s->nesting_info[max_depth - 1].max_def_level); + int thread_leaf_count, block_leaf_count; + block_scan(temp_storage.scan_storage) + .InclusiveSum(is_new_leaf, thread_leaf_count, block_leaf_count); + __syncthreads(); + + // if this thread is in row bounds + int const row_index = (thread_row_count + row_count) - 1; + in_row_bounds = + (row_index >= s->row_index_lower_bound) && (row_index < (s->first_row + s->num_rows)); + + // if we have not set skipped values yet, see if we found the first in-bounds row + if (!skipped_values_set) { + int local_count, global_count; + block_scan(temp_storage.scan_storage) + .InclusiveSum(in_row_bounds, local_count, global_count); + __syncthreads(); + + // we found it + if (global_count > 0) { + // this is the thread that represents the first row. + if (local_count == 1 && in_row_bounds) { + s->page.skipped_values = value_count + t; + s->page.skipped_leaf_values = + leaf_count + (is_new_leaf ? thread_leaf_count - 1 : thread_leaf_count); + } + skipped_values_set = true; + } } + + row_count += block_row_count; + leaf_count += block_leaf_count; } // increment value counts across all nesting depths for (int s_idx = 0; s_idx < max_depth; s_idx++) { - PageNestingInfo* pni = &s->page.nesting[s_idx]; - - // if we are within the range of nesting levels we should be adding value indices for - int const in_nesting_bounds = - (s_idx >= start_depth && s_idx <= end_depth && in_row_bounds) ? 
1 : 0; - uint32_t const count_mask = ballot(in_nesting_bounds); - if (!t) { pni->batch_size += __popc(count_mask); } + int const in_nesting_bounds = (s_idx >= start_depth && s_idx <= end_depth && in_row_bounds); + int const count = block_reduce(temp_storage.reduce_storage).Sum(in_nesting_bounds); + __syncthreads(); + if (!t) { + PageNestingInfo* pni = &s->page.nesting[s_idx]; + pni->batch_size += count; + } } - input_value_count += min(32, (target_input_value_count - input_value_count)); - input_row_count += __popc(warp_row_count_mask); - input_leaf_count += __popc(warp_leaf_count_mask); + value_count += batch_size; } - // update final page value count + // update final outputs if (!t) { - s->input_value_count = target_input_value_count; - s->input_leaf_count = input_leaf_count; - s->input_row_count = input_row_count; - } -} + s->input_value_count = value_count; -/** - * @brief Returns the total size in bytes of string char data in the page. - * - * This function expects the dictionary position to be at 0 and will traverse - * the entire thing. - * - * @param s The local page info - * @param t Thread index - */ -__device__ size_type gpuDecodeTotalPageStringSize(page_state_s* s, int t) -{ - size_type target_pos = s->num_input_values; - size_type str_len = 0; - if (s->dict_base) { - auto const [new_target_pos, len] = gpuDecodeDictionaryIndices(s, nullptr, target_pos, t); - target_pos = new_target_pos; - str_len = len; - } else if ((s->col.data_type & 7) == BYTE_ARRAY) { - str_len = gpuInitStringDescriptors(s, nullptr, target_pos, t); + // only used in the skip_rows/num_rows case + s->input_leaf_count = leaf_count; + s->input_row_count = row_count; } - if (!t) { *(volatile int32_t*)&s->dict_pos = target_pos; } - return str_len; } /** @@ -1744,7 +1797,8 @@ __device__ size_type gpuDecodeTotalPageStringSize(page_state_s* s, int t) * @param compute_string_sizes Whether or not we should be computing string sizes * (PageInfo::str_bytes) as part of the pass */ -__global__ void __launch_bounds__(block_size) +template +__global__ void __launch_bounds__(preprocess_block_size) gpuComputePageSizes(PageInfo* pages, device_span chunks, size_t min_row, @@ -1759,7 +1813,36 @@ __global__ void __launch_bounds__(block_size) int t = threadIdx.x; PageInfo* pp = &pages[page_idx]; - if (!setupLocalPageInfo(s, pp, chunks, min_row, num_rows, false)) { return; } + // whether or not we have repetition levels (lists) + bool has_repetition = chunks[pp->chunk_idx].max_level[level_type::REPETITION] > 0; + + // the level stream decoders + __shared__ rle_run def_runs[run_buffer_size]; + __shared__ rle_run rep_runs[run_buffer_size]; + rle_stream decoders[level_type::NUM_LEVEL_TYPES] = {{def_runs}, {rep_runs}}; + + // setup page info + if (!setupLocalPageInfo(s, pp, chunks, min_row, num_rows, false, decoders)) { return; } + + // initialize the stream decoders (requires values computed in setupLocalPageInfo) + int const max_batch_size = lvl_buf_size; + level_t* rep = reinterpret_cast(pp->lvl_decode_buf[level_type::REPETITION]); + level_t* def = reinterpret_cast(pp->lvl_decode_buf[level_type::DEFINITION]); + decoders[level_type::DEFINITION].init(s->col.level_bits[level_type::DEFINITION], + s->abs_lvl_start[level_type::DEFINITION], + s->abs_lvl_end[level_type::DEFINITION], + max_batch_size, + def, + s->page.num_input_values); + if (has_repetition) { + decoders[level_type::REPETITION].init(s->col.level_bits[level_type::REPETITION], + s->abs_lvl_start[level_type::REPETITION], + s->abs_lvl_end[level_type::REPETITION], + 
max_batch_size, + rep, + s->page.num_input_values); + } + __syncthreads(); if (!t) { s->page.skipped_values = -1; @@ -1779,7 +1862,6 @@ __global__ void __launch_bounds__(block_size) // we only need to preprocess hierarchies with repetition in them (ie, hierarchies // containing lists anywhere within). - bool const has_repetition = chunks[pp->chunk_idx].max_level[level_type::REPETITION] > 0; compute_string_sizes = compute_string_sizes && ((s->col.data_type & 7) == BYTE_ARRAY && s->dtype_len != 4); @@ -1829,40 +1911,32 @@ __global__ void __launch_bounds__(block_size) } depth += blockDim.x; } - __syncthreads(); - // optimization : it might be useful to have a version of gpuDecodeStream that could go wider than - // 1 warp. Currently it only uses 1 warp so that it can overlap work with the value decoding step - // when in the actual value decoding kernel. However, during this preprocess step we have no such - // limits - we could go as wide as block_size - if (t < 32) { - constexpr int batch_size = 32; - int target_input_count = batch_size; - while (!s->error && s->input_value_count < s->num_input_values) { - // decode repetition and definition levels. these will attempt to decode at - // least up to the target, but may decode a few more. - if (has_repetition) { - gpuDecodeStream(s->rep, s, target_input_count, t, level_type::REPETITION); - } - gpuDecodeStream(s->def, s, target_input_count, t, level_type::DEFINITION); - __syncwarp(); - - // we may have decoded different amounts from each stream, so only process what we've been - int actual_input_count = has_repetition ? min(s->lvl_count[level_type::REPETITION], - s->lvl_count[level_type::DEFINITION]) - : s->lvl_count[level_type::DEFINITION]; - - // process what we got back - gpuUpdatePageSizes(s, actual_input_count, t, !is_base_pass); - target_input_count = actual_input_count + batch_size; - __syncwarp(); + // the core loop. decode batches of level stream data using rle_stream objects + // and pass the results to gpuUpdatePageSizes + int processed = 0; + while (processed < s->page.num_input_values) { + // TODO: it would not take much more work to make it so that we could run both of these + // decodes concurrently. there are a couple of shared variables internally that would have to + // get dealt with but that's about it. + if (has_repetition) { + decoders[level_type::REPETITION].decode_next(t); + __syncthreads(); } + // the # of rep/def levels will always be the same size + processed += decoders[level_type::DEFINITION].decode_next(t); + __syncthreads(); - // retrieve total string size. - // TODO: investigate if it is possible to do this with a separate warp at the same time levels - // are being decoded above. - if (compute_string_sizes) { s->page.str_bytes = gpuDecodeTotalPageStringSize(s, t); } + // update page sizes + gpuUpdatePageSizes(s, processed, rep, def, t, !is_base_pass); + __syncthreads(); + } + + // retrieve total string size. 
+ // TODO: make this block-based instead of just 1 warp + if (compute_string_sizes) { + if (t < 32) { s->page.str_bytes = gpuDecodeTotalPageStringSize(s, t); } } // update output results: @@ -1925,7 +1999,8 @@ struct null_count_back_copier { * @param min_row Row index to start reading at * @param num_rows Maximum number of rows to read */ -__global__ void __launch_bounds__(block_size) gpuDecodePageData( +template +__global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( PageInfo* pages, device_span chunks, size_t min_row, size_t num_rows) { __shared__ __align__(16) page_state_s state_g; @@ -1938,7 +2013,9 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( int out_thread0; [[maybe_unused]] null_count_back_copier _{s, t}; - if (!setupLocalPageInfo(s, &pages[page_idx], chunks, min_row, num_rows, true)) { return; } + if (!setupLocalPageInfo(s, &pages[page_idx], chunks, min_row, num_rows, true)) { + return; + } bool const has_repetition = s->col.max_level[level_type::REPETITION] > 0; @@ -1966,6 +2043,9 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( PageNestingDecodeInfo* nesting_info_base = s->nesting_info; + __shared__ level_t rep[non_zero_buffer_size]; // circular buffer of repetition level values + __shared__ level_t def[non_zero_buffer_size]; // circular buffer of definition level values + // skipped_leaf_values will always be 0 for flat hierarchies. uint32_t skipped_leaf_values = s->page.skipped_leaf_values; while (!s->error && (s->input_value_count < s->num_input_values || s->src_pos < s->nz_count)) { @@ -1973,10 +2053,10 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( int src_pos = s->src_pos; if (t < out_thread0) { - target_pos = - min(src_pos + 2 * (block_size - out_thread0), s->nz_count + (block_size - out_thread0)); + target_pos = min(src_pos + 2 * (decode_block_size - out_thread0), + s->nz_count + (decode_block_size - out_thread0)); } else { - target_pos = min(s->nz_count, src_pos + block_size - out_thread0); + target_pos = min(s->nz_count, src_pos + decode_block_size - out_thread0); if (out_thread0 > 32) { target_pos = min(target_pos, s->dict_pos); } } __syncthreads(); @@ -1985,7 +2065,7 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData( // - update validity vectors // - updates offsets (for nested columns) // - produces non-NULL value indices in s->nz_idx for subsequent decoding - gpuDecodeLevels(s, sb, target_pos, t); + gpuDecodeLevels(s, sb, target_pos, rep, def, t); } else if (t < out_thread0) { // skipped_leaf_values will always be 0 for flat hierarchies. uint32_t src_target_pos = target_pos + skipped_leaf_values; @@ -2102,9 +2182,10 @@ void ComputePageSizes(hostdevice_vector& pages, size_t num_rows, bool compute_num_rows, bool compute_string_sizes, + int level_type_size, rmm::cuda_stream_view stream) { - dim3 dim_block(block_size, 1); + dim3 dim_block(preprocess_block_size, 1); dim3 dim_grid(pages.size(), 1); // 1 threadblock per page // computes: @@ -2112,8 +2193,14 @@ void ComputePageSizes(hostdevice_vector& pages, // This computes the size for the entire page, not taking row bounds into account. // If uses_custom_row_bounds is set to true, we have to do a second pass later that "trims" // the starting and ending read values to account for these bounds. 
- gpuComputePageSizes<<>>( - pages.device_ptr(), chunks, min_row, num_rows, compute_num_rows, compute_string_sizes); + if (level_type_size == 1) { + gpuComputePageSizes<<>>( + pages.device_ptr(), chunks, min_row, num_rows, compute_num_rows, compute_string_sizes); + } else { + gpuComputePageSizes + <<>>( + pages.device_ptr(), chunks, min_row, num_rows, compute_num_rows, compute_string_sizes); + } } /** @@ -2123,15 +2210,21 @@ void __host__ DecodePageData(hostdevice_vector& pages, hostdevice_vector const& chunks, size_t num_rows, size_t min_row, + int level_type_size, rmm::cuda_stream_view stream) { CUDF_EXPECTS(pages.size() > 0, "There is no page to decode"); - dim3 dim_block(block_size, 1); + dim3 dim_block(decode_block_size, 1); dim3 dim_grid(pages.size(), 1); // 1 threadblock per page - gpuDecodePageData<<>>( - pages.device_ptr(), chunks, min_row, num_rows); + if (level_type_size == 1) { + gpuDecodePageData + <<>>(pages.device_ptr(), chunks, min_row, num_rows); + } else { + gpuDecodePageData + <<>>(pages.device_ptr(), chunks, min_row, num_rows); + } } } // namespace gpu diff --git a/cpp/src/io/parquet/page_hdr.cu b/cpp/src/io/parquet/page_hdr.cu index ffb4cb60a20..76af22e068c 100644 --- a/cpp/src/io/parquet/page_hdr.cu +++ b/cpp/src/io/parquet/page_hdr.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -365,9 +365,11 @@ __global__ void __launch_bounds__(128) // this computation is only valid for flat schemas. for nested schemas, // they will be recomputed in the preprocess step by examining repetition and // definition levels - bs->page.chunk_row = 0; - bs->page.num_rows = 0; - bs->page.str_bytes = 0; + bs->page.chunk_row = 0; + bs->page.num_rows = 0; + bs->page.skipped_values = -1; + bs->page.skipped_leaf_values = 0; + bs->page.str_bytes = 0; } num_values = bs->ck.num_values; page_info = bs->ck.page_info; diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 4b577929e82..187e5b47fd7 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -45,6 +45,9 @@ constexpr int MAX_DICT_BITS = 24; // Total number of unsigned 24 bit values constexpr size_type MAX_DICT_SIZE = (1 << MAX_DICT_BITS) - 1; +// level decode buffer size. +constexpr int LEVEL_DECODE_BUF_SIZE = 2048; + /** * @brief Struct representing an input column in the file. 
*/ @@ -193,6 +196,9 @@ struct PageInfo { int32_t nesting_info_size; PageNestingInfo* nesting; PageNestingDecodeInfo* nesting_decode; + + // level decode buffers + uint8_t* lvl_decode_buf[level_type::NUM_LEVEL_TYPES]; }; /** @@ -284,6 +290,9 @@ struct file_intermediate_data { hostdevice_vector pages_info{}; hostdevice_vector page_nesting_info{}; hostdevice_vector page_nesting_decode_info{}; + + rmm::device_buffer level_decode_data; + int level_type_size; }; /** @@ -451,6 +460,7 @@ void BuildStringDictionaryIndex(ColumnChunkDesc* chunks, * computed * @param compute_string_sizes If set to true, the str_bytes field in PageInfo will * be computed + * @param level_type_size Size in bytes of the type for level decoding * @param stream CUDA stream to use, default 0 */ void ComputePageSizes(hostdevice_vector& pages, @@ -459,6 +469,7 @@ void ComputePageSizes(hostdevice_vector& pages, size_t num_rows, bool compute_num_rows, bool compute_string_sizes, + int level_type_size, rmm::cuda_stream_view stream); /** @@ -471,12 +482,14 @@ void ComputePageSizes(hostdevice_vector& pages, * @param[in] chunks All chunks to be decoded * @param[in] num_rows Total number of rows to read * @param[in] min_row Minimum number of rows to read + * @param[in] level_type_size Size in bytes of the type for level decoding * @param[in] stream CUDA stream to use, default 0 */ void DecodePageData(hostdevice_vector& pages, hostdevice_vector const& chunks, size_t num_rows, size_t min_row, + int level_type_size, rmm::cuda_stream_view stream); /** diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 9f1644dfd45..a3e07f9f255 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -121,7 +121,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) chunk_nested_valids.host_to_device(_stream); chunk_nested_data.host_to_device(_stream); - gpu::DecodePageData(pages, chunks, num_rows, skip_rows, _stream); + gpu::DecodePageData(pages, chunks, num_rows, skip_rows, _file_itm_data.level_type_size, _stream); pages.device_to_host(_stream); page_nesting.device_to_host(_stream); diff --git a/cpp/src/io/parquet/reader_impl.hpp b/cpp/src/io/parquet/reader_impl.hpp index 9b40610b141..4d627c41433 100644 --- a/cpp/src/io/parquet/reader_impl.hpp +++ b/cpp/src/io/parquet/reader_impl.hpp @@ -181,6 +181,14 @@ class reader::impl { */ void allocate_nesting_info(); + /** + * @brief Allocate space for use when decoding definition/repetition levels. + * + * One large contiguous buffer of data allocated and + * distributed among the PageInfo structs. + */ + void allocate_level_decode_space(); + /** * @brief Read a chunk of data and return an output table. 
* diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 14aaec48b2b..4433561ff1b 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -325,10 +325,11 @@ constexpr bool is_supported_encoding(Encoding enc) * @param chunks List of column chunk descriptors * @param pages List of page information * @param stream CUDA stream used for device memory operations and kernel launches + * @returns The size in bytes of level type data required */ -void decode_page_headers(hostdevice_vector& chunks, - hostdevice_vector& pages, - rmm::cuda_stream_view stream) +int decode_page_headers(hostdevice_vector& chunks, + hostdevice_vector& pages, + rmm::cuda_stream_view stream) { // IMPORTANT : if you change how pages are stored within a chunk (dist pages, then data pages), // please update preprocess_nested_columns to reflect this. @@ -340,6 +341,22 @@ void decode_page_headers(hostdevice_vector& chunks, chunks.host_to_device(stream); gpu::DecodePageHeaders(chunks.device_ptr(), chunks.size(), stream); + + // compute max bytes needed for level data + auto level_bit_size = + cudf::detail::make_counting_transform_iterator(0, [chunks = chunks.begin()] __device__(int i) { + auto c = chunks[i]; + return static_cast( + max(c.level_bits[gpu::level_type::REPETITION], c.level_bits[gpu::level_type::DEFINITION])); + }); + // max level data bit size. + int const max_level_bits = thrust::reduce(rmm::exec_policy(stream), + level_bit_size, + level_bit_size + chunks.size(), + 0, + thrust::maximum()); + auto const level_type_size = std::max(1, cudf::util::div_rounding_up_safe(max_level_bits, 8)); + pages.device_to_host(stream, true); // validate page encodings @@ -347,6 +364,8 @@ void decode_page_headers(hostdevice_vector& chunks, pages.end(), [](auto const& page) { return is_supported_encoding(page.encoding); }), "Unsupported page encoding detected"); + + return level_type_size; } /** @@ -565,9 +584,6 @@ void reader::impl::allocate_nesting_info() page_nesting_decode_info = hostdevice_vector{total_page_nesting_infos, _stream}; - // retrieve from the gpu so we can update - pages.device_to_host(_stream, true); - // update pointers in the PageInfos int target_page_index = 0; int src_info_index = 0; @@ -593,9 +609,6 @@ void reader::impl::allocate_nesting_info() target_page_index += chunks[idx].num_data_pages; } - // copy back to the gpu - pages.host_to_device(_stream); - // fill in int nesting_info_index = 0; std::map, std::vector>> depth_remapping; @@ -673,6 +686,30 @@ void reader::impl::allocate_nesting_info() page_nesting_decode_info.host_to_device(_stream); } +void reader::impl::allocate_level_decode_space() +{ + auto& pages = _file_itm_data.pages_info; + + // TODO: this could be made smaller if we ignored dictionary pages and pages with no + // repetition data. 
+ size_t const per_page_decode_buf_size = + LEVEL_DECODE_BUF_SIZE * 2 * _file_itm_data.level_type_size; + auto const decode_buf_size = per_page_decode_buf_size * pages.size(); + _file_itm_data.level_decode_data = + rmm::device_buffer(decode_buf_size, _stream, rmm::mr::get_current_device_resource()); + + // distribute the buffers + uint8_t* buf = static_cast(_file_itm_data.level_decode_data.data()); + for (size_t idx = 0; idx < pages.size(); idx++) { + auto& p = pages[idx]; + + p.lvl_decode_buf[gpu::level_type::DEFINITION] = buf; + buf += (LEVEL_DECODE_BUF_SIZE * _file_itm_data.level_type_size); + p.lvl_decode_buf[gpu::level_type::REPETITION] = buf; + buf += (LEVEL_DECODE_BUF_SIZE * _file_itm_data.level_type_size); + } +} + std::pair>> reader::impl::create_and_read_column_chunks( cudf::host_span const row_groups_info, size_type num_rows) { @@ -776,7 +813,7 @@ void reader::impl::load_and_decompress_data( auto& raw_page_data = _file_itm_data.raw_page_data; auto& decomp_page_data = _file_itm_data.decomp_page_data; auto& chunks = _file_itm_data.chunks; - auto& pages_info = _file_itm_data.pages_info; + auto& pages = _file_itm_data.pages_info; auto const [has_compressed_data, read_rowgroup_tasks] = create_and_read_column_chunks(row_groups_info, num_rows); @@ -787,13 +824,13 @@ void reader::impl::load_and_decompress_data( // Process dataset chunk pages into output columns auto const total_pages = count_page_headers(chunks, _stream); - pages_info = hostdevice_vector(total_pages, total_pages, _stream); + pages = hostdevice_vector(total_pages, total_pages, _stream); if (total_pages > 0) { // decoding of column/page information - decode_page_headers(chunks, pages_info, _stream); + _file_itm_data.level_type_size = decode_page_headers(chunks, pages, _stream); if (has_compressed_data) { - decomp_page_data = decompress_page_data(chunks, pages_info, _stream); + decomp_page_data = decompress_page_data(chunks, pages, _stream); // Free compressed data for (size_t c = 0; c < chunks.size(); c++) { if (chunks[c].codec != parquet::Compression::UNCOMPRESSED) { raw_page_data[c].reset(); } @@ -815,9 +852,17 @@ void reader::impl::load_and_decompress_data( // create it ourselves. // std::vector output_info = build_output_column_info(); - // nesting information (sizes, etc) stored -per page- - // note : even for flat schemas, we allocate 1 level of "nesting" info - allocate_nesting_info(); + // the following two allocate functions modify the page data + pages.device_to_host(_stream, true); + { + // nesting information (sizes, etc) stored -per page- + // note : even for flat schemas, we allocate 1 level of "nesting" info + allocate_nesting_info(); + + // level decode space + allocate_level_decode_space(); + } + pages.host_to_device(_stream); } } @@ -1575,6 +1620,7 @@ void reader::impl::preprocess_pages(size_t skip_rows, std::numeric_limits::max(), true, // compute num_rows chunk_read_limit > 0, // compute string sizes + _file_itm_data.level_type_size, _stream); // computes: @@ -1626,6 +1672,7 @@ void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses num_rows, false, // num_rows is already computed false, // no need to compute string sizes + _file_itm_data.level_type_size, _stream); // print_pages(pages, _stream); diff --git a/cpp/src/io/parquet/rle_stream.cuh b/cpp/src/io/parquet/rle_stream.cuh new file mode 100644 index 00000000000..473db660238 --- /dev/null +++ b/cpp/src/io/parquet/rle_stream.cuh @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "parquet_gpu.hpp" +#include +#include + +namespace cudf::io::parquet::gpu { + +// TODO: consider if these should be template parameters to rle_stream +constexpr int num_rle_stream_decode_threads = 512; +// the -1 here is for the look-ahead warp that fills in the list of runs to be decoded +// in an overlapped manner. so if we had 16 total warps: +// - warp 0 would be filling in batches of runs to be processed +// - warps 1-15 would be decoding the previous batch of runs generated +constexpr int num_rle_stream_decode_warps = + (num_rle_stream_decode_threads / cudf::detail::warp_size) - 1; +constexpr int run_buffer_size = (num_rle_stream_decode_warps * 2); +constexpr int rolling_run_index(int index) { return index % run_buffer_size; } + +/** + * @brief Read a 32-bit varint integer + * + * @param[in,out] cur The current data position, updated after the read + * @param[in] end The end data position + * + * @return The 32-bit value read + */ +inline __device__ uint32_t get_vlq32(uint8_t const*& cur, uint8_t const* end) +{ + uint32_t v = *cur++; + if (v >= 0x80 && cur < end) { + v = (v & 0x7f) | ((*cur++) << 7); + if (v >= (0x80 << 7) && cur < end) { + v = (v & ((0x7f << 7) | 0x7f)) | ((*cur++) << 14); + if (v >= (0x80 << 14) && cur < end) { + v = (v & ((0x7f << 14) | (0x7f << 7) | 0x7f)) | ((*cur++) << 21); + if (v >= (0x80 << 21) && cur < end) { + v = (v & ((0x7f << 21) | (0x7f << 14) | (0x7f << 7) | 0x7f)) | ((*cur++) << 28); + } + } + } + } + return v; +} + +// an individual batch. processed by a warp. +// batches should be in shared memory. +template +struct rle_batch { + uint8_t const* run_start; // start of the run we are part of + int run_offset; // value offset of this batch from the start of the run + level_t* output; + int level_run; + int size; + + __device__ inline void decode(uint8_t const* const end, int level_bits, int lane, int warp_id) + { + int output_pos = 0; + int remain = size; + + // for bitpacked/literal runs, total size is always a multiple of 8. so we need to take care if + // we are not starting/ending exactly on a run boundary + uint8_t const* cur; + if (level_run & 1) { + int const effective_offset = cudf::util::round_down_safe(run_offset, 8); + int const lead_values = (run_offset - effective_offset); + output_pos -= lead_values; + remain += lead_values; + cur = run_start + ((effective_offset >> 3) * level_bits); + } + + // if this is a repeated run, compute the repeated value + int level_val; + if (!(level_run & 1)) { + level_val = run_start[0]; + if (level_bits > 8) { level_val |= run_start[1] << 8; } + } + + // process + while (remain > 0) { + int const batch_len = min(32, remain); + + // if this is a literal run. 
each thread computes its own level_val + if (level_run & 1) { + int const batch_len8 = (batch_len + 7) >> 3; + if (lane < batch_len) { + int bitpos = lane * level_bits; + uint8_t const* cur_thread = cur + (bitpos >> 3); + bitpos &= 7; + level_val = 0; + if (cur_thread < end) { level_val = cur_thread[0]; } + cur_thread++; + if (level_bits > 8 - bitpos && cur_thread < end) { + level_val |= cur_thread[0] << 8; + cur_thread++; + if (level_bits > 16 - bitpos && cur_thread < end) { level_val |= cur_thread[0] << 16; } + } + level_val = (level_val >> bitpos) & ((1 << level_bits) - 1); + } + + cur += batch_len8 * level_bits; + } + + // store level_val + if (lane < batch_len && (lane + output_pos) >= 0) { output[lane + output_pos] = level_val; } + remain -= batch_len; + output_pos += batch_len; + } + } +}; + +// a single rle run. may be broken up into multiple rle_batches +template +struct rle_run { + int size; // total size of the run + int output_pos; + uint8_t const* start; + int level_run; // level_run header value + int remaining; + + __device__ __inline__ rle_batch next_batch(level_t* const output, int max_size) + { + int const batch_len = min(max_size, remaining); + int const run_offset = size - remaining; + remaining -= batch_len; + return rle_batch{start, run_offset, output, level_run, batch_len}; + } +}; + +// a stream of rle_runs +template +struct rle_stream { + int level_bits; + uint8_t const* start; + uint8_t const* cur; + uint8_t const* end; + + int max_output_values; + int total_values; + int cur_values; + + level_t* output; + + rle_run* runs; + int run_index; + int run_count; + int output_pos; + bool spill; + + int next_batch_run_start; + int next_batch_run_count; + + __device__ rle_stream(rle_run* _runs) : runs(_runs) {} + + __device__ void init(int _level_bits, + uint8_t const* _start, + uint8_t const* _end, + int _max_output_values, + level_t* _output, + int _total_values) + { + level_bits = _level_bits; + start = _start; + cur = _start; + end = _end; + + max_output_values = _max_output_values; + output = _output; + + run_index = 0; + run_count = 0; + output_pos = 0; + spill = false; + next_batch_run_start = 0; + next_batch_run_count = 0; + + total_values = _total_values; + cur_values = 0; + } + + __device__ inline thrust::pair get_run_batch() + { + return {next_batch_run_start, next_batch_run_count}; + } + + // fill in up to num_rle_stream_decode_warps runs or until we reach the max_count limit. + // this function is the critical hotspot. please be very careful altering it. + __device__ inline void fill_run_batch(int max_count) + { + // if we spilled over, we've already got a run at the beginning + next_batch_run_start = spill ? run_index - 1 : run_index; + spill = false; + + // generate runs until we either run out of warps to decode them with, or + // we cross the output limit. + while (run_count < num_rle_stream_decode_warps && output_pos < max_count && cur < end) { + auto& run = runs[rolling_run_index(run_index)]; + + // Encoding::RLE + + // bytes for the varint header + uint8_t const* _cur = cur; + int const level_run = get_vlq32(_cur, end); + int run_bytes = _cur - cur; + + // literal run + if (level_run & 1) { + int const run_size = (level_run >> 1) * 8; + run.size = run_size; + int const run_size8 = (run_size + 7) >> 3; + run_bytes += run_size8 * level_bits; + } + // repeated value run + else { + run.size = (level_run >> 1); + run_bytes++; + // can this ever be > 16? it effectively encodes nesting depth so that would require + // a nesting depth > 64k. 
+ if (level_bits > 8) { run_bytes++; } + } + run.output_pos = output_pos; + run.start = _cur; + run.level_run = level_run; + run.remaining = run.size; + cur += run_bytes; + + output_pos += run.size; + run_count++; + run_index++; + } + + // the above loop computes a batch of runs to be processed. mark down + // the number of runs because the code after this point resets run_count + // for the next batch. each batch is returned via get_next_batch(). + next_batch_run_count = run_count; + + // ------------------------------------- + // prepare for the next run: + + // if we've reached the value output limit on the last run + if (output_pos >= max_count) { + // first, see if we've spilled over + auto const& src = runs[rolling_run_index(run_index - 1)]; + int const spill_count = output_pos - max_count; + + // a spill has occurred in the current run. spill the extra values over into the beginning of + // the next run. + if (spill_count > 0) { + auto& spill_run = runs[rolling_run_index(run_index)]; + spill_run = src; + spill_run.output_pos = 0; + spill_run.remaining = spill_count; + + run_count = 1; + run_index++; + output_pos = spill_run.remaining; + spill = true; + } + // no actual spill needed. just reset the output pos + else { + output_pos = 0; + run_count = 0; + } + } + // didn't cross the limit, so reset the run count + else { + run_count = 0; + } + } + + __device__ inline int decode_next(int t) + { + int const output_count = min(max_output_values, (total_values - cur_values)); + + // special case. if level_bits == 0, just return all zeros. this should tremendously speed up + // a very common case: columns with no nulls, especially if they are non-nested + if (level_bits == 0) { + int written = 0; + while (written < output_count) { + int const batch_size = min(num_rle_stream_decode_threads, output_count - written); + if (t < batch_size) { output[written + t] = 0; } + written += batch_size; + } + cur_values += output_count; + return output_count; + } + + // otherwise, full decode. + int const warp_id = t / cudf::detail::warp_size; + int const warp_decode_id = warp_id - 1; + int const warp_lane = t % cudf::detail::warp_size; + + __shared__ int run_start; + __shared__ int num_runs; + __shared__ int values_processed; + if (!t) { + // carryover from the last call. + thrust::tie(run_start, num_runs) = get_run_batch(); + values_processed = 0; + } + __syncthreads(); + + do { + // warp 0 reads ahead and generates batches of runs to be decoded by remaining warps. + if (!warp_id) { + // fill the next set of runs. fill_runs will generally be the bottleneck for any + // kernel that uses an rle_stream. + if (warp_lane == 0) { fill_run_batch(output_count); } + } + // remaining warps decode the runs + else if (warp_decode_id < num_runs) { + // each warp handles 1 run, regardless of size. + // TODO: having each warp handle exactly 32 values would be ideal. as an example, the + // repetition levels for one of the list benchmarks decodes in ~3ms total, while the + // definition levels take ~11ms - the difference is entirely due to long runs in the + // definition levels. 
+ auto& run = runs[rolling_run_index(run_start + warp_decode_id)]; + auto batch = run.next_batch(output + run.output_pos, + min(run.remaining, (output_count - run.output_pos))); + batch.decode(end, level_bits, warp_lane, warp_decode_id); + // last warp updates total values processed + if (warp_lane == 0 && warp_decode_id == num_runs - 1) { + values_processed = run.output_pos + batch.size; + } + } + __syncthreads(); + + // if we haven't run out of space, retrieve the next batch. otherwise leave it for the next + // call. + if (!t && values_processed < output_count) { + thrust::tie(run_start, num_runs) = get_run_batch(); + } + __syncthreads(); + } while (num_runs > 0 && values_processed < output_count); + + cur_values += values_processed; + + // valid for every thread + return values_processed; + } +}; + +} // namespace cudf::io::parquet::gpu
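Usage sketch (not part of the patch): the block below illustrates how the new `rle_stream` interface is intended to be driven from a block-wide kernel, mirroring the pattern `gpuComputePageSizes` uses above. The kernel name `gpuLevelDecodeExample`, its parameters, and the use of shared-memory output buffers are assumptions made purely for the example; the real code decodes into the per-page `lvl_decode_buf` buffers and consumes the results in `gpuUpdatePageSizes`.

```cuda
// Illustrative sketch only. Kernel name, parameters, and shared-memory output
// buffers are hypothetical; the actual kernel writes decoded levels into
// PageInfo::lvl_decode_buf and processes them in gpuUpdatePageSizes.
#include "rle_stream.cuh"

namespace cudf::io::parquet::gpu {

template <typename level_t, int lvl_buf_size>
__global__ void __launch_bounds__(num_rle_stream_decode_threads)
  gpuLevelDecodeExample(uint8_t const* def_start, uint8_t const* def_end,
                        uint8_t const* rep_start, uint8_t const* rep_end,
                        int def_bits, int rep_bits, int num_values)
{
  int const t = threadIdx.x;

  // run queues shared by the fill warp (warp 0) and the decode warps (warps 1..N-1)
  __shared__ rle_run<level_t> def_runs[run_buffer_size];
  __shared__ rle_run<level_t> rep_runs[run_buffer_size];

  // decoded level values for the current batch (the real kernel uses global memory)
  __shared__ level_t def[lvl_buf_size];
  __shared__ level_t rep[lvl_buf_size];

  rle_stream<level_t> def_decoder{def_runs};
  rle_stream<level_t> rep_decoder{rep_runs};
  def_decoder.init(def_bits, def_start, def_end, lvl_buf_size, def, num_values);
  rep_decoder.init(rep_bits, rep_start, rep_end, lvl_buf_size, rep, num_values);
  __syncthreads();

  int processed = 0;
  while (processed < num_values) {
    // warp 0 reads ahead and fills batches of runs; the remaining warps decode them.
    // decode_next() returns how many level values were written to the output buffer,
    // and the rep/def streams always produce the same number of values per page.
    rep_decoder.decode_next(t);
    __syncthreads();
    processed += def_decoder.decode_next(t);
    __syncthreads();

    // ... consume def[]/rep[] here, e.g. a block-wide size or validity computation ...
    __syncthreads();
  }
}

}  // namespace cudf::io::parquet::gpu
```

The split between a single run-filling warp and the remaining decode warps is what lets this path run much wider than the previous single-warp `gpuDecodeStream` approach, at the cost of the double-buffered run queue (`run_buffer_size` entries) carried in shared memory.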