Skip to content

Commit

Permalink
Enable ZSTD compression in ORC and Parquet writers (#11551)
Browse files Browse the repository at this point in the history
Closes #9058, #9056

Expands nvCOMP adapter to include ZSTD compression.
Adds a centralized nvCOMP policy utility, `is_compression_enabled`.
Adds centralized nvCOMP alignment utility, `compress_input_alignment_bits`.
Adds centralized nvCOMP utility to get the maximum supported compression chunk size - `batched_compress_max_allowed_chunk_size`.
Encoded ORC row groups are aligned based on compression requirements.
Encoded Parquet pages are aligned based on compression requirements.
Parquet fragment size now scales with the page size to better fit the default page size with ZSTD compression.
Small refactoring around `decompress_status` for improved type safety and (hopefully) clearer naming.
Replaced `snappy_compress` from the Parquet writer with the nvCOMP adapter call.
Vectors of `compression_result`s are initialized before compression to avoid issues with random chunk skipping due to uninitialized memory.

Authors:
  - Vukasin Milovanovic (https://github.com/vuule)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - Jim Brennan (https://github.com/jbrennan333)
  - Mike Wilson (https://github.com/hyperbolic2346)
  - Tobias Ribizel (https://github.com/upsj)
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #11551
  • Loading branch information
vuule authored Sep 12, 2022
1 parent 39ad65f commit 578e65f
Show file tree
Hide file tree
Showing 31 changed files with 686 additions and 405 deletions.
12 changes: 7 additions & 5 deletions cpp/src/io/avro/reader_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ rmm::device_buffer decompress_data(datasource& source,
if (meta.codec == "deflate") {
auto inflate_in = hostdevice_vector<device_span<uint8_t const>>(meta.block_list.size(), stream);
auto inflate_out = hostdevice_vector<device_span<uint8_t>>(meta.block_list.size(), stream);
auto inflate_stats = hostdevice_vector<decompress_status>(meta.block_list.size(), stream);
auto inflate_stats = hostdevice_vector<compression_result>(meta.block_list.size(), stream);
thrust::fill(rmm::exec_policy(stream),
inflate_stats.d_begin(),
inflate_stats.d_end(),
compression_result{0, compression_status::FAILURE});

// Guess an initial maximum uncompressed block size. We estimate the compression factor is two
// and round up to the next multiple of 4096 bytes.
Expand All @@ -190,8 +194,6 @@ rmm::device_buffer decompress_data(datasource& source,

for (int loop_cnt = 0; loop_cnt < 2; loop_cnt++) {
inflate_out.host_to_device(stream);
CUDF_CUDA_TRY(cudaMemsetAsync(
inflate_stats.device_ptr(), 0, inflate_stats.memory_size(), stream.value()));
gpuinflate(inflate_in, inflate_out, inflate_stats, gzip_header_included::NO, stream);
inflate_stats.device_to_host(stream, true);

Expand All @@ -204,9 +206,9 @@ rmm::device_buffer decompress_data(datasource& source,
inflate_stats.begin(),
std::back_inserter(actual_uncomp_sizes),
[](auto const& inf_out, auto const& inf_stats) {
// If error status is 1 (buffer too small), the `bytes_written` field
// If error status is OUTPUT_OVERFLOW, the `bytes_written` field
// actually contains the uncompressed data size
return inf_stats.status == 1
return inf_stats.status == compression_status::OUTPUT_OVERFLOW
? std::max(inf_out.size(), inf_stats.bytes_written)
: inf_out.size();
});
Expand Down
15 changes: 8 additions & 7 deletions cpp/src/io/comp/debrotli.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1906,15 +1906,15 @@ static __device__ void ProcessCommands(debrotli_state_s* s, const brotli_diction
*
* @param[in] inputs Source buffer per block
* @param[out] outputs Destination buffer per block
* @param[out] statuses Decompressor status per block
* @param[out] results Decompressor status per block
* @param scratch Intermediate device memory heap space (will be dynamically shared between blocks)
* @param scratch_size Size of scratch heap space (smaller sizes may result in serialization between
* blocks)
*/
__global__ void __launch_bounds__(block_size, 2)
gpu_debrotli_kernel(device_span<device_span<uint8_t const> const> inputs,
device_span<device_span<uint8_t> const> outputs,
device_span<decompress_status> statuses,
device_span<compression_result> results,
uint8_t* scratch,
uint32_t scratch_size)
{
Expand Down Expand Up @@ -2016,10 +2016,11 @@ __global__ void __launch_bounds__(block_size, 2)
__syncthreads();
// Output decompression status
if (!t) {
statuses[block_id].bytes_written = s->out - s->outbase;
statuses[block_id].status = s->error;
results[block_id].bytes_written = s->out - s->outbase;
results[block_id].status =
(s->error == 0) ? compression_status::SUCCESS : compression_status::FAILURE;
// Return ext heap used by last block (statistics)
statuses[block_id].reserved = s->fb_size;
results[block_id].reserved = s->fb_size;
}
}

Expand Down Expand Up @@ -2079,7 +2080,7 @@ size_t __host__ get_gpu_debrotli_scratch_size(int max_num_inputs)

void gpu_debrotli(device_span<device_span<uint8_t const> const> inputs,
device_span<device_span<uint8_t> const> outputs,
device_span<decompress_status> statuses,
device_span<compression_result> results,
void* scratch,
size_t scratch_size,
rmm::cuda_stream_view stream)
Expand All @@ -2104,7 +2105,7 @@ void gpu_debrotli(device_span<device_span<uint8_t const> const> inputs,
cudaMemcpyHostToDevice,
stream.value()));
gpu_debrotli_kernel<<<dim_grid, dim_block, 0, stream.value()>>>(
inputs, outputs, statuses, scratch_u8, fb_heap_size);
inputs, outputs, results, scratch_u8, fb_heap_size);
#if DUMP_FB_HEAP
uint32_t dump[2];
uint32_t cur = 0;
Expand Down
20 changes: 13 additions & 7 deletions cpp/src/io/comp/gpuinflate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1020,14 +1020,14 @@ __device__ int parse_gzip_header(const uint8_t* src, size_t src_size)
* @tparam block_size Thread block dimension for this call
* @param inputs Source and destination buffer information per block
* @param outputs Destination buffer information per block
* @param statuses Decompression status buffer per block
* @param results Decompression status buffer per block
* @param parse_hdr If nonzero, indicates that the compressed bitstream includes a GZIP header
*/
template <int block_size>
__global__ void __launch_bounds__(block_size)
inflate_kernel(device_span<device_span<uint8_t const> const> inputs,
device_span<device_span<uint8_t> const> outputs,
device_span<decompress_status> statuses,
device_span<compression_result> results,
gzip_header_included parse_hdr)
{
__shared__ __align__(16) inflate_state_s state_g;
Expand Down Expand Up @@ -1133,9 +1133,15 @@ __global__ void __launch_bounds__(block_size)
// Output buffer too small
state->err = 1;
}
statuses[z].bytes_written = state->out - state->outbase;
statuses[z].status = state->err;
statuses[z].reserved = (int)(state->end - state->cur); // Here mainly for debug purposes
results[z].bytes_written = state->out - state->outbase;
results[z].status = [&]() {
switch (state->err) {
case 0: return compression_status::SUCCESS;
case 1: return compression_status::OUTPUT_OVERFLOW;
default: return compression_status::FAILURE;
}
}();
results[z].reserved = (int)(state->end - state->cur); // Here mainly for debug purposes
}
}

Expand Down Expand Up @@ -1200,14 +1206,14 @@ __global__ void __launch_bounds__(1024)

void gpuinflate(device_span<device_span<uint8_t const> const> inputs,
device_span<device_span<uint8_t> const> outputs,
device_span<decompress_status> statuses,
device_span<compression_result> results,
gzip_header_included parse_hdr,
rmm::cuda_stream_view stream)
{
constexpr int block_size = 128; // Threads per block
if (inputs.size() > 0) {
inflate_kernel<block_size>
<<<inputs.size(), block_size, 0, stream.value()>>>(inputs, outputs, statuses, parse_hdr);
<<<inputs.size(), block_size, 0, stream.value()>>>(inputs, outputs, results, parse_hdr);
}
}

Expand Down
32 changes: 21 additions & 11 deletions cpp/src/io/comp/gpuinflate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,21 @@ namespace cudf {
namespace io {

/**
* @brief Output parameters for the decompression interface
* @brief Status of a compression/decompression operation.
*/
struct decompress_status {
enum class compression_status : uint8_t {
SUCCESS, ///< Successful, output is valid
FAILURE, ///< Failed, output is invalid (e.g. input is unsupported in some way)
SKIPPED, ///< Operation skipped (if conversion, uncompressed data can be used)
OUTPUT_OVERFLOW, ///< Output buffer is too small; operation can succeed with larger output
};

/**
* @brief Descriptor of compression/decompression result.
*/
struct compression_result {
uint64_t bytes_written;
uint32_t status;
compression_status status;
uint32_t reserved;
};

Expand All @@ -44,13 +54,13 @@ enum class gzip_header_included { NO, YES };
*
* @param[in] inputs List of input buffers
* @param[out] outputs List of output buffers
* @param[out] statuses List of output status structures
* @param[out] results List of output status structures
* @param[in] parse_hdr Whether or not to parse GZIP header
* @param[in] stream CUDA stream to use
*/
void gpuinflate(device_span<device_span<uint8_t const> const> inputs,
device_span<device_span<uint8_t> const> outputs,
device_span<decompress_status> statuses,
device_span<compression_result> results,
gzip_header_included parse_hdr,
rmm::cuda_stream_view stream);

Expand All @@ -73,12 +83,12 @@ void gpu_copy_uncompressed_blocks(device_span<device_span<uint8_t const> const>
*
* @param[in] inputs List of input buffers
* @param[out] outputs List of output buffers
* @param[out] statuses List of output status structures
* @param[out] results List of output status structures
* @param[in] stream CUDA stream to use
*/
void gpu_unsnap(device_span<device_span<uint8_t const> const> inputs,
device_span<device_span<uint8_t> const> outputs,
device_span<decompress_status> statuses,
device_span<compression_result> results,
rmm::cuda_stream_view stream);

/**
Expand All @@ -98,14 +108,14 @@ size_t get_gpu_debrotli_scratch_size(int max_num_inputs = 0);
*
* @param[in] inputs List of input buffers
* @param[out] outputs List of output buffers
* @param[out] statuses List of output status structures
* @param[out] results List of output status structures
* @param[in] scratch Temporary memory for intermediate work
* @param[in] scratch_size Size in bytes of the temporary memory
* @param[in] stream CUDA stream to use
*/
void gpu_debrotli(device_span<device_span<uint8_t const> const> inputs,
device_span<device_span<uint8_t> const> outputs,
device_span<decompress_status> statuses,
device_span<compression_result> results,
void* scratch,
size_t scratch_size,
rmm::cuda_stream_view stream);
Expand All @@ -118,12 +128,12 @@ void gpu_debrotli(device_span<device_span<uint8_t const> const> inputs,
*
* @param[in] inputs List of input buffers
* @param[out] outputs List of output buffers
* @param[out] statuses List of output status structures
* @param[out] results List of output status structures
* @param[in] stream CUDA stream to use
*/
void gpu_snap(device_span<device_span<uint8_t const> const> inputs,
device_span<device_span<uint8_t> const> outputs,
device_span<decompress_status> statuses,
device_span<compression_result> results,
rmm::cuda_stream_view stream);

} // namespace io
Expand Down
Loading

0 comments on commit 578e65f

Please sign in to comment.