From 7940b5bb985e0db805f7effb0eb3de53e7f57b88 Mon Sep 17 00:00:00 2001 From: Vukasin Milovanovic Date: Fri, 21 Oct 2022 11:00:32 -0700 Subject: [PATCH] Fix maximum page size estimate in Parquet writer (#11962) Closes https://github.com/rapidsai/cudf/issues/11916 cuda memcheck reports an OOB write in one of the tests. The root cause is an underallocated buffer for encoded pages. This PR fixes the computation of the maximum size of data pages (RLE encoded) when dictionary encoding is used. Other changes: Refactored max RLE page size computation to avoid code repetition. Use actual dictionary index width instead of (outdated) worst case. Authors: - Vukasin Milovanovic (https://github.com/vuule) Approvers: - David Wendt (https://github.com/davidwendt) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/11962 --- cpp/src/io/parquet/page_enc.cu | 38 ++++++++++++++++------------------ 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 7c5651b1ef8..15bd4fe17e3 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -228,6 +228,14 @@ __global__ void __launch_bounds__(128) } } +constexpr uint32_t max_RLE_page_size(uint8_t value_bit_width, uint32_t num_values) +{ + if (value_bit_width == 0) return 0; + + // Run length = 4, max(rle/bitpack header) = 5, add one byte per 256 values for overhead + return 4 + 5 + util::div_rounding_up_unsafe(num_values * value_bit_width, 8) + (num_values / 256); +} + // blockDim {128,1,1} __global__ void __launch_bounds__(128) gpuInitPages(device_2dspan chunks, @@ -340,7 +348,7 @@ __global__ void __launch_bounds__(128) __syncwarp(); uint32_t fragment_data_size = (ck_g.use_dictionary) - ? frag_g.num_leaf_values * 2 // Assume worst-case of 2-bytes per dictionary index + ? frag_g.num_leaf_values * util::div_rounding_up_unsafe(ck_g.dict_rle_bits, 8) : frag_g.fragment_data_size; // TODO (dm): this convoluted logic to limit page size needs refactoring size_t this_max_page_size = (values_in_page * 2 >= ck_g.num_values) ? 256 * 1024 @@ -354,8 +362,8 @@ __global__ void __launch_bounds__(128) (values_in_page > 0 && (page_size + fragment_data_size > this_max_page_size)) || rows_in_page >= max_page_size_rows) { if (ck_g.use_dictionary) { - page_size = - 1 + 5 + ((values_in_page * ck_g.dict_rle_bits + 7) >> 3) + (values_in_page >> 8); + // Additional byte to store entry bit width + page_size = 1 + max_RLE_page_size(ck_g.dict_rle_bits, values_in_page); } if (!t) { page_g.num_fragments = fragments_in_chunk - page_start; @@ -378,23 +386,13 @@ __global__ void __launch_bounds__(128) if (not comp_page_sizes.empty()) { page_g.compressed_data = ck_g.compressed_bfr + comp_page_offset; } - page_g.start_row = cur_row; - page_g.num_rows = rows_in_page; - page_g.num_leaf_values = leaf_values_in_page; - page_g.num_values = values_in_page; - uint32_t def_level_bits = col_g.num_def_level_bits(); - uint32_t rep_level_bits = col_g.num_rep_level_bits(); - // Run length = 4, max(rle/bitpack header) = 5, add one byte per 256 values for overhead - // TODO (dm): Improve readability of these calculations. - uint32_t def_level_size = - (def_level_bits != 0) - ? 4 + 5 + ((def_level_bits * page_g.num_values + 7) >> 3) + (page_g.num_values >> 8) - : 0; - uint32_t rep_level_size = - (rep_level_bits != 0) - ? 4 + 5 + ((rep_level_bits * page_g.num_values + 7) >> 3) + (page_g.num_values >> 8) - : 0; - page_g.max_data_size = page_size + def_level_size + rep_level_size; + page_g.start_row = cur_row; + page_g.num_rows = rows_in_page; + page_g.num_leaf_values = leaf_values_in_page; + page_g.num_values = values_in_page; + auto const def_level_size = max_RLE_page_size(col_g.num_def_level_bits(), values_in_page); + auto const rep_level_size = max_RLE_page_size(col_g.num_rep_level_bits(), values_in_page); + page_g.max_data_size = page_size + def_level_size + rep_level_size; pagestats_g.start_chunk = ck_g.first_fragment + page_start; pagestats_g.num_chunks = page_g.num_fragments;