diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh index 2f513df4b5a..1fa05b3a6c2 100644 --- a/cpp/src/io/parquet/delta_binary.cuh +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -101,7 +101,7 @@ struct delta_binary_decoder { uint8_t const* cur_mb_start; // pointer to the start of the current mini-block data uint8_t const* cur_bitwidths; // pointer to the bitwidth array in the block - uleb128_t value[delta_rolling_buf_size]; // circular buffer of delta values + zigzag128_t value[delta_rolling_buf_size]; // circular buffer of delta values // returns the value stored in the `value` array at index // `rolling_index(idx)`. If `idx` is `0`, then return `first_value`. @@ -299,6 +299,49 @@ struct delta_binary_decoder { } } + // Decodes and skips values until the block containing the value after `skip` is reached. + // Keeps a running sum of the values and returns that upon exit. Called by all threads in + // warp 0. Result is only valid on thread 0. + // This is intended for use only by the DELTA_LENGTH_BYTE_ARRAY decoder. + inline __device__ size_t skip_values_and_sum(int skip) + { + using cudf::detail::warp_size; + // DELTA_LENGTH_BYTE_ARRAY lengths are encoded as INT32 by convention (since the PLAIN encoding + // uses 4-byte lengths). + using delta_length_type = int32_t; + using warp_reduce = cub::WarpReduce; + __shared__ warp_reduce::TempStorage temp_storage; + int const t = threadIdx.x; + + // initialize sum with first value, which is stored in the block header. cast to + // `delta_length_type` to ensure the value is interpreted properly before promoting it + // back to `size_t`.
+ size_t sum = static_cast(value_at(0)); + + // if only skipping one value, we're done already + if (skip == 1) { return sum; } + + // need to do in multiple passes if values_per_mb != 32 + uint32_t const num_pass = values_per_mb / warp_size; + + while (current_value_idx < skip && current_value_idx < num_encoded_values(true)) { + calc_mini_block_values(t); + + int const idx = current_value_idx + t; + + for (uint32_t p = 0; p < num_pass; p++) { + auto const pidx = idx + p * warp_size; + size_t const val = pidx < skip ? static_cast(value_at(pidx)) : 0; + auto const warp_sum = warp_reduce(temp_storage).Sum(val); + if (t == 0) { sum += warp_sum; } + } + if (t == 0) { setup_next_mini_block(true); } + __syncwarp(); + } + + return sum; + } + // decodes the current mini block and stores the values obtained. should only be called by // a single warp. inline __device__ void decode_batch() diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh index cbb44d30a56..f90d364f5eb 100644 --- a/cpp/src/io/parquet/delta_enc.cuh +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -186,6 +186,7 @@ class delta_binary_packer { _bitpack_tmp = _buffer + delta::buffer_size; _current_idx = 0; _values_in_buffer = 0; + _buffer[0] = 0; } // Each thread calls this to add its current value. @@ -215,7 +216,7 @@ class delta_binary_packer { } // Called by each thread to flush data to the sink. - inline __device__ uint8_t const* flush() + inline __device__ uint8_t* flush() { using cudf::detail::warp_size; __shared__ T block_min; @@ -224,6 +225,10 @@ class delta_binary_packer { int const warp_id = t / warp_size; int const lane_id = t % warp_size; + // if no values have been written, still need to write the header + if (t == 0 && _current_idx == 0) { write_header(); } + + // if there are no values to write, just return if (_values_in_buffer <= 0) { return _dst; } // Calculate delta for this thread. 
diff --git a/cpp/src/io/parquet/page_decode.cuh b/cpp/src/io/parquet/page_decode.cuh index f6f2f9e9f18..a378ea33a0c 100644 --- a/cpp/src/io/parquet/page_decode.cuh +++ b/cpp/src/io/parquet/page_decode.cuh @@ -1332,6 +1332,7 @@ inline __device__ bool setupLocalPageInfo(page_state_s* const s, s->dict_run = 0; } break; case Encoding::DELTA_BINARY_PACKED: + case Encoding::DELTA_LENGTH_BYTE_ARRAY: case Encoding::DELTA_BYTE_ARRAY: // nothing to do, just don't error break; diff --git a/cpp/src/io/parquet/page_delta_decode.cu b/cpp/src/io/parquet/page_delta_decode.cu index 98f8fbb09a2..44ec0e1e027 100644 --- a/cpp/src/io/parquet/page_delta_decode.cu +++ b/cpp/src/io/parquet/page_delta_decode.cu @@ -18,6 +18,8 @@ #include "page_string_utils.cuh" #include "parquet_gpu.hpp" +#include + #include #include @@ -463,7 +465,7 @@ __global__ void __launch_bounds__(decode_block_size) bool const has_repetition = s->col.max_level[level_type::REPETITION] > 0; // choose a character parallel string copy when the average string is longer than a warp - auto const use_char_ll = (s->page.str_bytes / s->page.num_valids) > cudf::detail::warp_size; + auto const use_char_ll = (s->page.str_bytes / s->page.num_valids) > warp_size; // copying logic from gpuDecodePageData. PageNestingDecodeInfo const* nesting_info_base = s->nesting_info; @@ -493,6 +495,7 @@ __global__ void __launch_bounds__(decode_block_size) int const leaf_level_index = s->col.max_nesting_depth - 1; auto strings_data = nesting_info_base[leaf_level_index].string_out; + // sanity check to make sure we can process this page auto const batch_size = prefix_db->values_per_mb; if (batch_size > max_delta_mini_block_size) { set_error(static_cast(decode_error::DELTA_PARAMS_UNSUPPORTED), @@ -581,18 +584,174 @@ __global__ void __launch_bounds__(decode_block_size) if (t == 0 and s->error != 0) { set_error(s->error, error_code); } } +// Decode page data that is DELTA_LENGTH_BYTE_ARRAY packed. 
This encoding consists of a +// DELTA_BINARY_PACKED array of string lengths, followed by the string data. +template +__global__ void __launch_bounds__(decode_block_size) + gpuDecodeDeltaLengthByteArray(PageInfo* pages, + device_span chunks, + size_t min_row, + size_t num_rows, + kernel_error::pointer error_code) +{ + using cudf::detail::warp_size; + __shared__ __align__(16) delta_binary_decoder db_state; + __shared__ __align__(16) page_state_s state_g; + __shared__ __align__(16) page_state_buffers_s state_buffers; + __shared__ __align__(8) uint8_t const* page_string_data; + __shared__ size_t string_offset; + + page_state_s* const s = &state_g; + auto* const sb = &state_buffers; + int const page_idx = blockIdx.x; + int const t = threadIdx.x; + int const lane_id = t % warp_size; + auto* const db = &db_state; + [[maybe_unused]] null_count_back_copier _{s, t}; + + auto const mask = decode_kernel_mask::DELTA_LENGTH_BA; + if (!setupLocalPageInfo(s, + &pages[page_idx], + chunks, + min_row, + num_rows, + mask_filter{mask}, + page_processing_stage::DECODE)) { + return; + } + + bool const has_repetition = s->col.max_level[level_type::REPETITION] > 0; + + // copying logic from gpuDecodePageData. + PageNestingDecodeInfo const* nesting_info_base = s->nesting_info; + + __shared__ level_t rep[delta_rolling_buf_size]; // circular buffer of repetition level values + __shared__ level_t def[delta_rolling_buf_size]; // circular buffer of definition level values + + // skipped_leaf_values will always be 0 for flat hierarchies. 
+ uint32_t const skipped_leaf_values = s->page.skipped_leaf_values; + + // initialize delta state + if (t == 0) { + string_offset = 0; + page_string_data = db->find_end_of_block(s->data_start, s->data_end); + } + __syncthreads(); + + int const leaf_level_index = s->col.max_nesting_depth - 1; + + // sanity check to make sure we can process this page + auto const batch_size = db->values_per_mb; + if (batch_size > max_delta_mini_block_size) { + set_error(static_cast(decode_error::DELTA_PARAMS_UNSUPPORTED), error_code); + return; + } + + // if this is a bounds page, then we need to decode up to the first mini-block + // that has a value we need, and set string_offset to the position of the first value in the + // string data block. + auto const is_bounds_pg = is_bounds_page(s, min_row, num_rows, has_repetition); + if (is_bounds_pg && s->page.start_val > 0) { + if (t < warp_size) { + // string_off is only valid on thread 0 + auto const string_off = db->skip_values_and_sum(s->page.start_val); + if (t == 0) { + string_offset = string_off; + + // if there is no repetition, then we need to work through the whole page, so reset the + // delta decoder to the beginning of the page + if (not has_repetition) { db->init_binary_block(s->data_start, s->data_end); } + } + } + __syncthreads(); + } + + int string_pos = has_repetition ? s->page.start_val : 0; + + while (!s->error && (s->input_value_count < s->num_input_values || s->src_pos < s->nz_count)) { + uint32_t target_pos; + uint32_t const src_pos = s->src_pos; + + if (t < 2 * warp_size) { // warp0..1 + target_pos = min(src_pos + 2 * batch_size, s->nz_count + batch_size); + } else { // warp2 + target_pos = min(s->nz_count, src_pos + batch_size); + } + // this needs to be here to prevent warp 2 modifying src_pos before all threads have read it + __syncthreads(); + + // warp0 will decode the rep/def levels, warp1 will unpack a mini-batch of deltas. 
+ // warp2 waits one cycle for warps 0/1 to produce a batch, and then stuffs string sizes + // into the proper location in the output. warp 3 does nothing until it's time to copy + // string data. + if (t < warp_size) { + // warp 0 + // decode repetition and definition levels. + // - update validity vectors + // - updates offsets (for nested columns) + // - produces non-NULL value indices in s->nz_idx for subsequent decoding + gpuDecodeLevels(s, sb, target_pos, rep, def, t); + } else if (t < 2 * warp_size) { + // warp 1 + db->decode_batch(); + + } else if (t < 3 * warp_size && src_pos < target_pos) { + // warp 2 + int const nproc = min(batch_size, s->page.end_val - string_pos); + string_pos += nproc; + + // process the mini-block in batches of 32 + for (uint32_t sp = src_pos + lane_id; sp < src_pos + batch_size; sp += 32) { + // the position in the output column/buffer + int dst_pos = sb->nz_idx[rolling_index(sp)]; + + // handle skip_rows here. flat hierarchies can just skip up to first_row. + if (!has_repetition) { dst_pos -= s->first_row; } + + // fill in offsets array + if (dst_pos >= 0 && sp < target_pos) { + auto const offptr = + reinterpret_cast(nesting_info_base[leaf_level_index].data_out) + dst_pos; + *offptr = db->value_at(sp + skipped_leaf_values); + } + __syncwarp(); + } + + if (lane_id == 0) { s->src_pos = src_pos + batch_size; } + } + __syncthreads(); + } + + // now turn array of lengths into offsets + int value_count = nesting_info_base[leaf_level_index].value_count; + + // if no repetition we haven't calculated start/end bounds and instead just skipped + // values until we reach first_row. account for that here. 
+ if (!has_repetition) { value_count -= s->first_row; } + + auto const offptr = reinterpret_cast(nesting_info_base[leaf_level_index].data_out); + block_excl_sum(offptr, value_count, s->page.str_offset); + + // finally, copy the string data into place + auto const dst = nesting_info_base[leaf_level_index].string_out; + auto const src = page_string_data + string_offset; + memcpy_block(dst, src, s->page.str_bytes, t); + + if (t == 0 and s->error != 0) { set_error(s->error, error_code); } +} + } // anonymous namespace /** * @copydoc cudf::io::parquet::detail::DecodeDeltaBinary */ -void __host__ DecodeDeltaBinary(cudf::detail::hostdevice_vector& pages, - cudf::detail::hostdevice_vector const& chunks, - size_t num_rows, - size_t min_row, - int level_type_size, - kernel_error::pointer error_code, - rmm::cuda_stream_view stream) +void DecodeDeltaBinary(cudf::detail::hostdevice_vector& pages, + cudf::detail::hostdevice_vector const& chunks, + size_t num_rows, + size_t min_row, + int level_type_size, + kernel_error::pointer error_code, + rmm::cuda_stream_view stream) { CUDF_EXPECTS(pages.size() > 0, "There is no page to decode"); @@ -611,13 +770,13 @@ void __host__ DecodeDeltaBinary(cudf::detail::hostdevice_vector& pages /** * @copydoc cudf::io::parquet::gpu::DecodeDeltaByteArray */ -void __host__ DecodeDeltaByteArray(cudf::detail::hostdevice_vector& pages, - cudf::detail::hostdevice_vector const& chunks, - size_t num_rows, - size_t min_row, - int level_type_size, - kernel_error::pointer error_code, - rmm::cuda_stream_view stream) +void DecodeDeltaByteArray(cudf::detail::hostdevice_vector& pages, + cudf::detail::hostdevice_vector const& chunks, + size_t num_rows, + size_t min_row, + int level_type_size, + kernel_error::pointer error_code, + rmm::cuda_stream_view stream) { CUDF_EXPECTS(pages.size() > 0, "There is no page to decode"); @@ -633,4 +792,29 @@ void __host__ DecodeDeltaByteArray(cudf::detail::hostdevice_vector& pa } } +/** + * @copydoc 
cudf::io::parquet::detail::DecodeDeltaLengthByteArray + */ +void DecodeDeltaLengthByteArray(cudf::detail::hostdevice_vector& pages, + cudf::detail::hostdevice_vector const& chunks, + size_t num_rows, + size_t min_row, + int level_type_size, + kernel_error::pointer error_code, + rmm::cuda_stream_view stream) +{ + CUDF_EXPECTS(pages.size() > 0, "There is no page to decode"); + + dim3 const dim_block(decode_block_size, 1); + dim3 const dim_grid(pages.size(), 1); // 1 threadblock per page + + if (level_type_size == 1) { + gpuDecodeDeltaLengthByteArray<<>>( + pages.device_ptr(), chunks, min_row, num_rows, error_code); + } else { + gpuDecodeDeltaLengthByteArray<<>>( + pages.device_ptr(), chunks, min_row, num_rows, error_code); + } +} + } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 7b0eabdcfd4..8e1c0682ffd 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -468,7 +468,10 @@ __global__ void __launch_bounds__(128) } } -__device__ size_t delta_data_len(Type physical_type, cudf::type_id type_id, uint32_t num_values) +__device__ size_t delta_data_len(Type physical_type, + cudf::type_id type_id, + uint32_t num_values, + size_t page_size) { auto const dtype_len_out = physical_type_len(physical_type, type_id); auto const dtype_len = [&]() -> uint32_t { @@ -495,7 +498,15 @@ __device__ size_t delta_data_len(Type physical_type, cudf::type_id type_id, uint // modified. auto const header_size = 2 + 1 + 5 + dtype_len + 1; - return header_size + num_blocks * block_size; + // The above is just a size estimate for a DELTA_BINARY_PACKED data page. For BYTE_ARRAY + // data we also need to add size of the char data. `page_size` that is passed in is the + // plain encoded size (i.e. num_values * sizeof(size_type) + char_data_len), so the char + // data len is `page_size` minus the first term.
+ // TODO: this will need to change for DELTA_BYTE_ARRAY encoding + auto const char_data_len = + physical_type == BYTE_ARRAY ? page_size - num_values * sizeof(size_type) : 0; + + return header_size + num_blocks * block_size + char_data_len; } // blockDim {128,1,1} @@ -535,7 +546,8 @@ __global__ void __launch_bounds__(128) auto const physical_type = col_g.physical_type; auto const type_id = col_g.leaf_column->type().id(); auto const is_use_delta = - write_v2_headers && !ck_g.use_dictionary && (physical_type == INT32 || physical_type == INT64); + write_v2_headers && !ck_g.use_dictionary && + (physical_type == INT32 || physical_type == INT64 || physical_type == BYTE_ARRAY); if (t < 32) { uint32_t fragments_in_chunk = 0; @@ -696,8 +708,9 @@ __global__ void __launch_bounds__(128) auto const rep_level_size = max_RLE_page_size(col_g.num_rep_level_bits(), values_in_page); // get a different bound if using delta encoding if (is_use_delta) { - page_size = - max(page_size, delta_data_len(physical_type, type_id, page_g.num_leaf_values)); + auto const delta_len = + delta_data_len(physical_type, type_id, page_g.num_leaf_values, page_size); + page_size = max(page_size, delta_len); } auto const max_data_size = page_size + def_level_size + rep_level_size + rle_pad; // page size must fit in 32-bit signed integer @@ -728,7 +741,12 @@ __global__ void __launch_bounds__(128) if (not pages.empty()) { // set encoding if (is_use_delta) { - page_g.kernel_mask = encode_kernel_mask::DELTA_BINARY; + // TODO(ets): at some point make a more intelligent decision on this. DELTA_LENGTH_BA + // should always be preferred over PLAIN, but DELTA_BINARY is a different matter. + // If the delta encoding size is going to be close to 32 bits anyway, then plain + // is a better choice. + page_g.kernel_mask = physical_type == BYTE_ARRAY ? 
encode_kernel_mask::DELTA_LENGTH_BA + : encode_kernel_mask::DELTA_BINARY; } else if (ck_g.use_dictionary || physical_type == BOOLEAN) { page_g.kernel_mask = encode_kernel_mask::DICTIONARY; } else { @@ -1307,7 +1325,7 @@ __global__ void __launch_bounds__(block_size, 8) gpuEncodePageLevels(device_span __syncthreads(); // if max_def <= 1, then the histogram is trivial to calculate if (s->page.def_histogram != nullptr and s->col.max_def_level > 1) { - // Only calculate up to max_def_level...the last entry is valid_count and will be filled + // Only calculate up to max_def_level...the last entry is num_valid and will be filled // in later. generate_def_level_histogram( s->page.def_histogram, s, nrows, rle_numvals, s->col.max_def_level); @@ -1755,7 +1773,6 @@ __global__ void __launch_bounds__(block_size, 8) if (BitAnd(s->page.kernel_mask, encode_kernel_mask::DICTIONARY) == 0) { return; } // Encode data values - __syncthreads(); auto const physical_type = s->col.physical_type; auto const type_id = s->col.leaf_column->type().id(); auto const dtype_len_out = physical_type_len(physical_type, type_id); @@ -1888,7 +1905,6 @@ __global__ void __launch_bounds__(block_size, 8) if (BitAnd(s->page.kernel_mask, encode_kernel_mask::DELTA_BINARY) == 0) { return; } // Encode data values - __syncthreads(); auto const physical_type = s->col.physical_type; auto const type_id = s->col.leaf_column->type().id(); auto const dtype_len_out = physical_type_len(physical_type, type_id); @@ -1956,6 +1972,134 @@ __global__ void __launch_bounds__(block_size, 8) finish_page_encode(s, delta_ptr, pages, comp_in, comp_out, comp_results, true); } +// DELTA_LENGTH_BYTE_ARRAY page data encoder +// blockDim(128, 1, 1) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodeDeltaLengthByteArrayPages(device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results) +{ + // block of shared memory for value storage and bit packing + __shared__ uleb128_t 
delta_shared[delta::buffer_size + delta::block_size]; + __shared__ __align__(8) page_enc_state_s<0> state_g; + __shared__ delta_binary_packer packer; + __shared__ uint8_t const* first_string; + __shared__ size_type string_data_len; + using block_reduce = cub::BlockReduce; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename delta_binary_packer::index_scan::TempStorage delta_index_tmp; + typename delta_binary_packer::block_reduce::TempStorage delta_reduce_tmp; + typename delta_binary_packer::warp_reduce::TempStorage + delta_warp_red_tmp[delta::num_mini_blocks]; + } temp_storage; + + auto* const s = &state_g; + uint32_t t = threadIdx.x; + + if (t == 0) { + state_g = page_enc_state_s<0>{}; + s->page = pages[blockIdx.x]; + s->ck = *s->page.chunk; + s->col = *s->ck.col_desc; + s->rle_len_pos = nullptr; + // get s->cur back to where it was at the end of encoding the rep and def level data + s->cur = + s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + } + __syncthreads(); + + if (BitAnd(s->page.kernel_mask, encode_kernel_mask::DELTA_LENGTH_BA) == 0) { return; } + + // Encode data values + if (t == 0) { + uint8_t* dst = s->cur; + s->rle_run = 0; + s->rle_pos = 0; + s->rle_numvals = 0; + s->rle_out = dst; + s->page.encoding = Encoding::DELTA_LENGTH_BYTE_ARRAY; + s->page_start_val = row_to_value_idx(s->page.start_row, s->col); + s->chunk_start_val = row_to_value_idx(s->ck.start_row, s->col); + } + __syncthreads(); + + auto const type_id = s->col.leaf_column->type().id(); + + // encode the lengths as DELTA_BINARY_PACKED + if (t == 0) { + first_string = nullptr; + packer.init(s->cur, s->page.num_valid, reinterpret_cast(delta_shared), &temp_storage); + + // if there are valid values, find a pointer to the first valid string + if (s->page.num_valid != 0) { + for (uint32_t idx = 0; idx < s->page.num_leaf_values; idx++) { + size_type const idx_in_col = s->page_start_val + idx; + if 
(s->col.leaf_column->is_valid(idx_in_col)) { + if (type_id == type_id::STRING) { + first_string = reinterpret_cast( + s->col.leaf_column->element(idx_in_col).data()); + } else if (s->col.output_as_byte_array && type_id == type_id::LIST) { + first_string = reinterpret_cast( + get_element(*s->col.leaf_column, idx_in_col).data()); + } + break; + } + } + } + } + __syncthreads(); + + uint32_t len = 0; + for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { + uint32_t const nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size); + + size_type const val_idx_in_block = cur_val_idx + t; + size_type const val_idx = s->page_start_val + val_idx_in_block; + + bool const is_valid = + (val_idx < s->col.leaf_column->size() && val_idx_in_block < s->page.num_leaf_values) + ? s->col.leaf_column->is_valid(val_idx) + : false; + + cur_val_idx += nvals; + + int32_t v = 0; + if (is_valid) { + if (type_id == type_id::STRING) { + v = s->col.leaf_column->element(val_idx).size_bytes(); + } else if (s->col.output_as_byte_array && type_id == type_id::LIST) { + auto const arr_size = + get_element(*s->col.leaf_column, val_idx).size_bytes(); + // the lengths are assumed to be INT32, check for overflow + if (arr_size > static_cast(std::numeric_limits::max())) { + CUDF_UNREACHABLE("byte array size exceeds 2GB"); + } + v = static_cast(arr_size); + } + len += v; + } + + packer.add_value(v, is_valid); + } + + // string_len is only valid on thread 0 + auto const string_len = block_reduce(temp_storage.reduce_storage).Sum(len); + if (t == 0) { string_data_len = string_len; } + __syncthreads(); + + // finish off the delta block and get the pointer to the end of the delta block + auto const output_ptr = packer.flush(); + + // now copy the char data + memcpy_block(output_ptr, first_string, string_data_len, t); + + finish_page_encode( + s, output_ptr + string_data_len, pages, comp_in, comp_out, comp_results, true); +} + constexpr int decide_compression_warps_in_block = 4; 
constexpr int decide_compression_block_size = decide_compression_warps_in_block * cudf::detail::warp_size; @@ -2906,6 +3050,13 @@ void EncodePages(device_span pages, gpuEncodeDeltaBinaryPages <<>>(pages, comp_in, comp_out, comp_results); } + if (BitAnd(kernel_mask, encode_kernel_mask::DELTA_LENGTH_BA) != 0) { + auto const strm = streams[s_idx++]; + gpuEncodePageLevels<<>>( + pages, write_v2_headers, encode_kernel_mask::DELTA_LENGTH_BA); + gpuEncodeDeltaLengthByteArrayPages + <<>>(pages, comp_in, comp_out, comp_results); + } if (BitAnd(kernel_mask, encode_kernel_mask::DICTIONARY) != 0) { auto const strm = streams[s_idx++]; gpuEncodePageLevels<<>>( diff --git a/cpp/src/io/parquet/page_hdr.cu b/cpp/src/io/parquet/page_hdr.cu index 114e47aa507..36157f725e3 100644 --- a/cpp/src/io/parquet/page_hdr.cu +++ b/cpp/src/io/parquet/page_hdr.cu @@ -156,6 +156,8 @@ __device__ decode_kernel_mask kernel_mask_for_page(PageInfo const& page, return decode_kernel_mask::DELTA_BINARY; } else if (page.encoding == Encoding::DELTA_BYTE_ARRAY) { return decode_kernel_mask::DELTA_BYTE_ARRAY; + } else if (page.encoding == Encoding::DELTA_LENGTH_BYTE_ARRAY) { + return decode_kernel_mask::DELTA_LENGTH_BA; } else if (is_string_col(chunk)) { return decode_kernel_mask::STRING; } diff --git a/cpp/src/io/parquet/page_string_decode.cu b/cpp/src/io/parquet/page_string_decode.cu index ef2e7ef42ef..d559f93f45b 100644 --- a/cpp/src/io/parquet/page_string_decode.cu +++ b/cpp/src/io/parquet/page_string_decode.cu @@ -35,6 +35,7 @@ namespace { constexpr int preprocess_block_size = 512; constexpr int decode_block_size = 128; constexpr int delta_preproc_block_size = 64; +constexpr int delta_length_block_size = 32; constexpr int rolling_buf_size = decode_block_size * 2; constexpr int preproc_buf_size = LEVEL_DECODE_BUF_SIZE; @@ -615,13 +616,12 @@ __global__ void __launch_bounds__(preprocess_block_size) gpuComputeStringPageBou {rep_runs}}; // setup page info - auto const mask = BitOr(decode_kernel_mask::STRING, 
decode_kernel_mask::DELTA_BYTE_ARRAY); if (!setupLocalPageInfo(s, pp, chunks, min_row, num_rows, - mask_filter{mask}, + mask_filter{STRINGS_MASK}, page_processing_stage::STRING_BOUNDS)) { return; } @@ -703,7 +703,6 @@ __global__ void __launch_bounds__(delta_preproc_block_size) gpuComputeDeltaPageS auto const [len, temp_bytes] = totalDeltaByteArraySize(data, end, start_value, end_value); if (t == 0) { - // TODO check for overflow pp->str_bytes = len; // only need temp space if we're skipping values @@ -712,6 +711,104 @@ __global__ void __launch_bounds__(delta_preproc_block_size) gpuComputeDeltaPageS } } +/** + * @brief Kernel for computing string page output size information for DELTA_LENGTH_BYTE_ARRAY + * encoding. + * + * This call ignores columns that are not DELTA_LENGTH_BYTE_ARRAY encoded. On exit the `str_bytes` + * field of the `PageInfo` struct will be populated. + * + * Currently this function only supports being called by a single warp. + * + * @param pages All pages to be decoded + * @param chunks All chunks to be decoded + * @param min_row crop all rows below min_row + * @param num_rows Maximum number of rows to read + */ +__global__ void __launch_bounds__(delta_length_block_size) gpuComputeDeltaLengthPageStringSizes( + PageInfo* pages, device_span chunks, size_t min_row, size_t num_rows) +{ + using cudf::detail::warp_size; + using WarpReduce = cub::WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage; + __shared__ __align__(16) page_state_s state_g; + __shared__ __align__(16) delta_binary_decoder string_lengths; + + page_state_s* const s = &state_g; + int const page_idx = blockIdx.x; + int const t = threadIdx.x; + PageInfo* const pp = &pages[page_idx]; + + // whether or not we have repetition levels (lists) + bool const has_repetition = chunks[pp->chunk_idx].max_level[level_type::REPETITION] > 0; + + // setup page info + if (!setupLocalPageInfo(s, + pp, + chunks, + min_row, + num_rows, +
mask_filter{decode_kernel_mask::DELTA_LENGTH_BA}, + page_processing_stage::STRING_BOUNDS)) { + return; + } + + bool const is_bounds_pg = is_bounds_page(s, min_row, num_rows, has_repetition); + + // for DELTA_LENGTH_BYTE_ARRAY, string size is page_data_size - size_of_delta_binary_block. + // so all we need to do is skip the encoded string size info and then do pointer arithmetic, + // if this isn't a bounds page. + if (not is_bounds_pg) { + if (t == 0) { + auto const* string_start = string_lengths.find_end_of_block(s->data_start, s->data_end); + size_t len = static_cast(s->data_end - string_start); + pp->str_bytes = len; + } + } else { + // now process string info in the range [start_value, end_value) + // set up for decoding the delta-encoded string lengths + auto const start_value = pp->start_val; + auto const end_value = pp->end_val; + + if (t == 0) { string_lengths.init_binary_block(s->data_start, s->data_end); } + __syncwarp(); + + size_t total_bytes = 0; + + // initialize with first value (unless there are no values) + if (t == 0 && start_value == 0 && start_value < end_value) { + total_bytes = string_lengths.value_at(0); + } + + uleb128_t lane_sum = 0; + while (string_lengths.current_value_idx < end_value && + string_lengths.current_value_idx < string_lengths.num_encoded_values(true)) { + // calculate values for current mini-block + string_lengths.calc_mini_block_values(t); + + // get per lane sum for mini-block + for (uint32_t i = 0; i < string_lengths.values_per_mb; i += warp_size) { + uint32_t const idx = string_lengths.current_value_idx + i + t; + if (idx >= start_value && idx < end_value && idx < string_lengths.value_count) { + lane_sum += string_lengths.value[rolling_index(idx)]; + } + } + + if (t == 0) { string_lengths.setup_next_mini_block(true); } + __syncwarp(); + } + + // get sum for warp. + // note: warp_sum will only be valid on lane 0.
+ auto const warp_sum = WarpReduce(temp_storage).Sum(lane_sum); + + if (t == 0) { + total_bytes += warp_sum; + pp->str_bytes = total_bytes; + } + } +} + /** * @brief Kernel for computing string page output size information. * @@ -1030,10 +1127,9 @@ void ComputePageStringSizes(cudf::detail::hostdevice_vector& pages, } // kernel mask may contain other kernels we don't need to count - int const count_mask = - kernel_mask & BitOr(decode_kernel_mask::DELTA_BYTE_ARRAY, decode_kernel_mask::STRING); - int const nkernels = std::bitset<32>(count_mask).count(); - auto const streams = cudf::detail::fork_streams(stream, nkernels); + int const count_mask = kernel_mask & STRINGS_MASK; + int const nkernels = std::bitset<32>(count_mask).count(); + auto const streams = cudf::detail::fork_streams(stream, nkernels); int s_idx = 0; if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_BYTE_ARRAY) != 0) { @@ -1041,6 +1137,11 @@ void ComputePageStringSizes(cudf::detail::hostdevice_vector& pages, gpuComputeDeltaPageStringSizes<<>>( pages.device_ptr(), chunks, min_row, num_rows); } + if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_LENGTH_BA) != 0) { + dim3 dim_delta(delta_length_block_size, 1); + gpuComputeDeltaLengthPageStringSizes<<>>( + pages.device_ptr(), chunks, min_row, num_rows); + } if (BitAnd(kernel_mask, decode_kernel_mask::STRING) != 0) { gpuComputePageStringSizes<<>>( pages.device_ptr(), chunks, min_row, num_rows); diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 7f557d092c5..18d282be855 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -82,6 +82,7 @@ constexpr bool is_supported_encoding(Encoding enc) case Encoding::RLE: case Encoding::RLE_DICTIONARY: case Encoding::DELTA_BINARY_PACKED: + case Encoding::DELTA_LENGTH_BYTE_ARRAY: case Encoding::DELTA_BYTE_ARRAY: return true; default: return false; } @@ -204,9 +205,14 @@ enum class decode_kernel_mask { GENERAL = (1 << 0), // Run catch-all decode 
kernel STRING = (1 << 1), // Run decode kernel for string data DELTA_BINARY = (1 << 2), // Run decode kernel for DELTA_BINARY_PACKED data - DELTA_BYTE_ARRAY = (1 << 3) // Run decode kernel for DELTA_BYTE_ARRAY encoded data + DELTA_BYTE_ARRAY = (1 << 3), // Run decode kernel for DELTA_BYTE_ARRAY encoded data + DELTA_LENGTH_BA = (1 << 4), // Run decode kernel for DELTA_LENGTH_BYTE_ARRAY encoded data }; +// mask representing all the ways in which a string can be encoded +constexpr uint32_t STRINGS_MASK = + BitOr(BitOr(decode_kernel_mask::DELTA_BYTE_ARRAY, decode_kernel_mask::STRING), + decode_kernel_mask::DELTA_LENGTH_BA); /** * @brief Nesting information specifically needed by the decode and preprocessing * kernels. @@ -474,9 +480,10 @@ constexpr uint32_t encoding_to_mask(Encoding encoding) * Used to control which encode kernels to run. */ enum class encode_kernel_mask { - PLAIN = (1 << 0), // Run plain encoding kernel - DICTIONARY = (1 << 1), // Run dictionary encoding kernel - DELTA_BINARY = (1 << 2) // Run DELTA_BINARY_PACKED encoding kernel + PLAIN = (1 << 0), // Run plain encoding kernel + DICTIONARY = (1 << 1), // Run dictionary encoding kernel + DELTA_BINARY = (1 << 2), // Run DELTA_BINARY_PACKED encoding kernel + DELTA_LENGTH_BA = (1 << 3), // Run DELTA_LENGTH_BYTE_ARRAY encoding kernel }; /** @@ -752,6 +759,28 @@ void DecodeDeltaByteArray(cudf::detail::hostdevice_vector& pages, kernel_error::pointer error_code, rmm::cuda_stream_view stream); +/** + * @brief Launches kernel for reading the DELTA_LENGTH_BYTE_ARRAY column data stored in the pages + * + * The page data will be written to the output pointed to in the page's + * associated column chunk. 
+ * + * @param[in,out] pages All pages to be decoded + * @param[in] chunks All chunks to be decoded + * @param[in] num_rows Total number of rows to read + * @param[in] min_row Minimum number of rows to read + * @param[in] level_type_size Size in bytes of the type for level decoding + * @param[out] error_code Error code for kernel failures + * @param[in] stream CUDA stream to use + */ +void DecodeDeltaLengthByteArray(cudf::detail::hostdevice_vector& pages, + cudf::detail::hostdevice_vector const& chunks, + size_t num_rows, + size_t min_row, + int level_type_size, + kernel_error::pointer error_code, + rmm::cuda_stream_view stream); + /** * @brief Launches kernel for initializing encoder row group fragments * diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 6e799424d01..c1082c0305a 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -56,8 +56,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) // doing a gather operation later on. // TODO: This step is somewhat redundant if size info has already been calculated (nested schema, // chunked reader). 
- auto const has_strings = - (kernel_mask & BitOr(decode_kernel_mask::STRING, decode_kernel_mask::DELTA_BYTE_ARRAY)) != 0; + auto const has_strings = (kernel_mask & STRINGS_MASK) != 0; std::vector col_sizes(_input_columns.size(), 0L); if (has_strings) { ComputePageStringSizes( @@ -190,6 +189,12 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]); } + // launch delta length byte array decoder + if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_LENGTH_BA) != 0) { + DecodeDeltaLengthByteArray( + pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]); + } + // launch delta binary decoder if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_BINARY) != 0) { DecodeDeltaBinary( diff --git a/cpp/src/io/utilities/block_utils.cuh b/cpp/src/io/utilities/block_utils.cuh index f028b0bb367..e285936e49e 100644 --- a/cpp/src/io/utilities/block_utils.cuh +++ b/cpp/src/io/utilities/block_utils.cuh @@ -157,7 +157,7 @@ inline __device__ void memcpy_block(void* dstv, void const* srcv, uint32_t len, uint32_t align_len = min(dst_align_bytes, len); uint8_t b; if (t < align_len) { b = src[t]; } - if (sync_before_store) { __syncthreads(); } + if constexpr (sync_before_store) { __syncthreads(); } if (t < align_len) { dst[t] = b; } src += align_len; dst += align_len; @@ -173,7 +173,7 @@ inline __device__ void memcpy_block(void* dstv, void const* srcv, uint32_t len, v = src32[t]; if (src_align_bits != 0) { v = __funnelshift_r(v, src32[t + 1], src_align_bits); } } - if (sync_before_store) { __syncthreads(); } + if constexpr (sync_before_store) { __syncthreads(); } if (t < copy_cnt) { reinterpret_cast(dst)[t] = v; } src += copy_cnt * 4; dst += copy_cnt * 4; @@ -182,7 +182,7 @@ inline __device__ void memcpy_block(void* dstv, void const* srcv, uint32_t len, if (len != 0) { uint8_t b; if (t < len) { b = src[t]; } - if (sync_before_store) { __syncthreads(); } 
+ if constexpr (sync_before_store) { __syncthreads(); } if (t < len) { dst[t] = b; } } } diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index 39a4a89af92..785a398d716 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -780,8 +780,9 @@ TEST_P(ParquetV2Test, Strings) cudf::test::expect_metadata_equal(expected_metadata, result.metadata); } -TEST_F(ParquetWriterTest, StringsAsBinary) +TEST_P(ParquetV2Test, StringsAsBinary) { + auto const is_v2 = GetParam(); std::vector unicode_strings{ "Monday", "Wȅdnȅsday", "Friday", "Monday", "Friday", "Friday", "Friday", "Funday"}; std::vector ascii_strings{ @@ -815,11 +816,13 @@ TEST_F(ParquetWriterTest, StringsAsBinary) expected_metadata.column_metadata[1].set_name("col_string").set_output_as_binary(true); expected_metadata.column_metadata[2].set_name("col_another").set_output_as_binary(true); expected_metadata.column_metadata[3].set_name("col_binary"); - expected_metadata.column_metadata[4].set_name("col_binary"); + expected_metadata.column_metadata[4].set_name("col_binary2"); auto filepath = temp_env->get_temp_filepath("BinaryStrings.parquet"); cudf::io::parquet_writer_options out_opts = cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, write_tbl) + .write_v2_headers(is_v2) + .dictionary_policy(cudf::io::dictionary_policy::NEVER) .metadata(expected_metadata); cudf::io::write_parquet(out_opts); @@ -7085,4 +7088,217 @@ TEST_F(ParquetReaderTest, RepeatedNoAnnotations) CUDF_TEST_EXPECT_TABLES_EQUAL(result.tbl->view(), expected); } +inline auto random_validity(std::mt19937& engine) +{ + static std::bernoulli_distribution bn(0.7f); + return cudf::detail::make_counting_transform_iterator(0, [&](int index) { return bn(engine); }); +} + +template +std::unique_ptr make_parquet_list_col(std::mt19937& engine, + int num_rows, + int max_vals_per_row, + bool include_validity) +{ + std::vector row_sizes(num_rows); + + auto const min_values_per_row = 
include_validity ? 0 : 1; + std::uniform_int_distribution dist{min_values_per_row, max_vals_per_row}; + std::generate_n(row_sizes.begin(), num_rows, [&]() { return cudf::size_type{dist(engine)}; }); + + std::vector offsets(num_rows + 1); + std::exclusive_scan(row_sizes.begin(), row_sizes.end(), offsets.begin(), 0); + offsets[num_rows] = offsets[num_rows - 1] + row_sizes.back(); + + std::vector values = random_values(offsets[num_rows]); + cudf::test::fixed_width_column_wrapper offsets_col(offsets.begin(), + offsets.end()); + + if (include_validity) { + auto valids = random_validity(engine); + auto values_col = + cudf::test::fixed_width_column_wrapper(values.begin(), values.end(), valids); + auto [null_mask, null_count] = cudf::test::detail::make_null_mask(valids, valids + num_rows); + + auto col = cudf::make_lists_column( + num_rows, offsets_col.release(), values_col.release(), null_count, std::move(null_mask)); + return cudf::purge_nonempty_nulls(*col); + } else { + auto values_col = cudf::test::fixed_width_column_wrapper(values.begin(), values.end()); + return cudf::make_lists_column(num_rows, + offsets_col.release(), + values_col.release(), + 0, + cudf::create_null_mask(num_rows, cudf::mask_state::ALL_VALID)); + } +} + +std::vector string_values(std::mt19937& engine, int num_rows, int max_string_len) +{ + static std::uniform_int_distribution char_dist{'a', 'z'}; + static std::uniform_int_distribution strlen_dist{1, max_string_len}; + + std::vector values(num_rows); + std::generate_n(values.begin(), values.size(), [&]() { + int str_len = strlen_dist(engine); + std::string res = ""; + for (int i = 0; i < str_len; i++) { + res += char_dist(engine); + } + return res; + }); + + return values; +} + +// make a random list column, with random string lengths of 0..max_string_len, +// and up to max_vals_per_row strings in each list. 
+std::unique_ptr make_parquet_string_list_col(std::mt19937& engine, + int num_rows, + int max_vals_per_row, + int max_string_len, + bool include_validity) +{ + auto const range_min = include_validity ? 0 : 1; + + std::uniform_int_distribution dist{range_min, max_vals_per_row}; + + std::vector row_sizes(num_rows); + std::generate_n(row_sizes.begin(), num_rows, [&]() { return cudf::size_type{dist(engine)}; }); + + std::vector offsets(num_rows + 1); + std::exclusive_scan(row_sizes.begin(), row_sizes.end(), offsets.begin(), 0); + offsets[num_rows] = offsets[num_rows - 1] + row_sizes.back(); + + std::uniform_int_distribution strlen_dist{range_min, max_string_len}; + auto const values = string_values(engine, offsets[num_rows], max_string_len); + + cudf::test::fixed_width_column_wrapper offsets_col(offsets.begin(), + offsets.end()); + + if (include_validity) { + auto valids = random_validity(engine); + auto values_col = cudf::test::strings_column_wrapper(values.begin(), values.end(), valids); + auto [null_mask, null_count] = cudf::test::detail::make_null_mask(valids, valids + num_rows); + + auto col = cudf::make_lists_column( + num_rows, offsets_col.release(), values_col.release(), null_count, std::move(null_mask)); + return cudf::purge_nonempty_nulls(*col); + } else { + auto values_col = cudf::test::strings_column_wrapper(values.begin(), values.end()); + return cudf::make_lists_column(num_rows, + offsets_col.release(), + values_col.release(), + 0, + cudf::create_null_mask(num_rows, cudf::mask_state::ALL_VALID)); + } +} + +TEST_F(ParquetReaderTest, DeltaSkipRowsWithNulls) +{ + constexpr int num_rows = 50'000; + constexpr auto seed = 21337; + + std::mt19937 engine{seed}; + auto int32_list_nulls = make_parquet_list_col(engine, num_rows, 5, true); + auto int32_list = make_parquet_list_col(engine, num_rows, 5, false); + auto int64_list_nulls = make_parquet_list_col(engine, num_rows, 5, true); + auto int64_list = make_parquet_list_col(engine, num_rows, 5, false); + auto 
int16_list_nulls = make_parquet_list_col(engine, num_rows, 5, true); + auto int16_list = make_parquet_list_col(engine, num_rows, 5, false); + auto int8_list_nulls = make_parquet_list_col(engine, num_rows, 5, true); + auto int8_list = make_parquet_list_col(engine, num_rows, 5, false); + + auto str_list_nulls = make_parquet_string_list_col(engine, num_rows, 5, 32, true); + auto str_list = make_parquet_string_list_col(engine, num_rows, 5, 32, false); + auto big_str_list_nulls = make_parquet_string_list_col(engine, num_rows, 5, 256, true); + auto big_str_list = make_parquet_string_list_col(engine, num_rows, 5, 256, false); + + auto int32_data = random_values(num_rows); + auto int64_data = random_values(num_rows); + auto int16_data = random_values(num_rows); + auto int8_data = random_values(num_rows); + auto str_data = string_values(engine, num_rows, 32); + auto big_str_data = string_values(engine, num_rows, 256); + + auto const validity = random_validity(engine); + auto const no_nulls = cudf::test::iterators::no_nulls(); + column_wrapper int32_nulls_col{int32_data.begin(), int32_data.end(), validity}; + column_wrapper int32_col{int32_data.begin(), int32_data.end(), no_nulls}; + column_wrapper int64_nulls_col{int64_data.begin(), int64_data.end(), validity}; + column_wrapper int64_col{int64_data.begin(), int64_data.end(), no_nulls}; + + auto str_col = cudf::test::strings_column_wrapper(str_data.begin(), str_data.end(), no_nulls); + auto str_col_nulls = cudf::purge_nonempty_nulls( + cudf::test::strings_column_wrapper(str_data.begin(), str_data.end(), validity)); + auto big_str_col = + cudf::test::strings_column_wrapper(big_str_data.begin(), big_str_data.end(), no_nulls); + auto big_str_col_nulls = cudf::purge_nonempty_nulls( + cudf::test::strings_column_wrapper(big_str_data.begin(), big_str_data.end(), validity)); + + cudf::table_view tbl({int32_col, int32_nulls_col, *int32_list, *int32_list_nulls, + int64_col, int64_nulls_col, *int64_list, *int64_list_nulls, + 
*int16_list, *int16_list_nulls, *int8_list, *int8_list_nulls, + str_col, *str_col_nulls, *str_list, *str_list_nulls, + big_str_col, *big_str_col_nulls, *big_str_list, *big_str_list_nulls}); + + auto const filepath = temp_env->get_temp_filepath("DeltaSkipRowsWithNulls.parquet"); + auto const out_opts = + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, tbl) + .stats_level(cudf::io::statistics_freq::STATISTICS_COLUMN) + .compression(cudf::io::compression_type::NONE) + .dictionary_policy(cudf::io::dictionary_policy::NEVER) + .max_page_size_rows(20'000) + .write_v2_headers(true) + .build(); + cudf::io::write_parquet(out_opts); + + // skip_rows / num_rows + // clang-format off + std::vector> params{ + // skip and then read rest of file + {-1, -1}, {1, -1}, {2, -1}, {32, -1}, {33, -1}, {128, -1}, {1000, -1}, + // no skip but read fewer rows + {0, 1}, {0, 2}, {0, 31}, {0, 32}, {0, 33}, {0, 128}, {0, 129}, {0, 130}, + // skip and truncate + {1, 32}, {1, 33}, {32, 32}, {33, 139}, + // cross page boundaries + {10'000, 20'000} + }; + + // clang-format on + for (auto p : params) { + cudf::io::parquet_reader_options read_args = + cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}); + if (p.first >= 0) { read_args.set_skip_rows(p.first); } + if (p.second >= 0) { read_args.set_num_rows(p.second); } + auto result = cudf::io::read_parquet(read_args); + + p.first = p.first < 0 ? 0 : p.first; + p.second = p.second < 0 ? 
num_rows - p.first : p.second; + std::vector slice_indices{p.first, p.first + p.second}; + std::vector expected = cudf::slice(tbl, slice_indices); + + CUDF_TEST_EXPECT_TABLES_EQUAL(result.tbl->view(), expected[0]); + + // test writing the result back out as a further check of the delta writer's correctness + std::vector out_buffer; + cudf::io::parquet_writer_options out_opts2 = + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{&out_buffer}, + result.tbl->view()) + .stats_level(cudf::io::statistics_freq::STATISTICS_COLUMN) + .compression(cudf::io::compression_type::NONE) + .dictionary_policy(cudf::io::dictionary_policy::NEVER) + .max_page_size_rows(20'000) + .write_v2_headers(true); + cudf::io::write_parquet(out_opts2); + + cudf::io::parquet_reader_options default_in_opts = cudf::io::parquet_reader_options::builder( + cudf::io::source_info{out_buffer.data(), out_buffer.size()}); + auto const result2 = cudf::io::read_parquet(default_in_opts); + + CUDF_TEST_EXPECT_TABLES_EQUAL(result.tbl->view(), result2.tbl->view()); + } +} + CUDF_TEST_PROGRAM_MAIN() diff --git a/python/cudf/cudf/tests/data/parquet/delta_encoding.parquet b/python/cudf/cudf/tests/data/parquet/delta_encoding.parquet index 29565bef4d2..ea6952e5bcd 100644 Binary files a/python/cudf/cudf/tests/data/parquet/delta_encoding.parquet and b/python/cudf/cudf/tests/data/parquet/delta_encoding.parquet differ diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index c5da03d2942..5c9e3aa3d9f 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -1352,8 +1352,13 @@ def test_delta_binary(nrows, add_nulls, dtype, tmpdir): @pytest.mark.parametrize("nrows", delta_num_rows()) @pytest.mark.parametrize("add_nulls", [True, False]) -@pytest.mark.parametrize("str_encoding", ["DELTA_BYTE_ARRAY"]) -def test_delta_byte_array_roundtrip(nrows, add_nulls, str_encoding, tmpdir): +@pytest.mark.parametrize("max_string_length", [12, 48, 
96, 128]) +@pytest.mark.parametrize( + "str_encoding", ["DELTA_BYTE_ARRAY", "DELTA_LENGTH_BYTE_ARRAY"] +) +def test_delta_byte_array_roundtrip( + nrows, add_nulls, max_string_length, str_encoding, tmpdir +): null_frequency = 0.25 if add_nulls else 0 # Create a pandas dataframe with random data of mixed lengths @@ -1363,13 +1368,7 @@ def test_delta_byte_array_roundtrip(nrows, add_nulls, str_encoding, tmpdir): "dtype": "str", "null_frequency": null_frequency, "cardinality": nrows, - "max_string_length": 10, - }, - { - "dtype": "str", - "null_frequency": null_frequency, - "cardinality": nrows, - "max_string_length": 100, + "max_string_length": max_string_length, }, ], rows=nrows, @@ -1391,10 +1390,24 @@ def test_delta_byte_array_roundtrip(nrows, add_nulls, str_encoding, tmpdir): pcdf = cudf.from_pandas(test_pdf) assert_eq(cdf, pcdf) + # Test DELTA_LENGTH_BYTE_ARRAY writing as well + if str_encoding == "DELTA_LENGTH_BYTE_ARRAY": + cudf_fname = tmpdir.join("cdfdeltaba.parquet") + pcdf.to_parquet( + cudf_fname, + compression="snappy", + header_version="2.0", + use_dictionary=False, + ) + cdf2 = cudf.from_pandas(pd.read_parquet(cudf_fname)) + assert_eq(cdf2, cdf) + @pytest.mark.parametrize("nrows", delta_num_rows()) @pytest.mark.parametrize("add_nulls", [True, False]) -@pytest.mark.parametrize("str_encoding", ["DELTA_BYTE_ARRAY"]) +@pytest.mark.parametrize( + "str_encoding", ["DELTA_BYTE_ARRAY", "DELTA_LENGTH_BYTE_ARRAY"] +) def test_delta_struct_list(tmpdir, nrows, add_nulls, str_encoding): # Struct> lists_per_row = 3 @@ -1441,7 +1454,20 @@ def string_list_gen_wrapped(x, y): # sanity check to verify file is written properly assert_eq(test_pdf, pd.read_parquet(pdf_fname)) cdf = cudf.read_parquet(pdf_fname) - assert_eq(cdf, cudf.from_pandas(test_pdf)) + pcdf = cudf.from_pandas(test_pdf) + assert_eq(cdf, pcdf) + + # Test DELTA_LENGTH_BYTE_ARRAY writing as well + if str_encoding == "DELTA_LENGTH_BYTE_ARRAY": + cudf_fname = tmpdir.join("cdfdeltaba.parquet") + pcdf.to_parquet( 
+ cudf_fname, + compression="snappy", + header_version="2.0", + use_dictionary=False, + ) + cdf2 = cudf.from_pandas(pd.read_parquet(cudf_fname)) + assert_eq(cdf2, cdf) @pytest.mark.parametrize(