Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use nvcomp's snappy compressor in parquet writer #8229

Merged
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
db23741
Initial changes to get nvcomp integrated
devavret May 7, 2021
a5f3363
Using nvcomp provided max compressed buffer size
devavret May 12, 2021
61018aa
Recover from error in nvcomp compressing and encode uncompressed.
devavret May 12, 2021
64d7d1c
review changes
devavret May 13, 2021
27764e7
Replace accidental vector with uvector.
devavret May 14, 2021
95a57ec
Provide the actual max uncomp page size to nvcomp's temp size estimat…
devavret May 14, 2021
cc9500a
cmake changes requested in review
devavret May 14, 2021
7989b9c
Merge branch 'branch-21.10' into parquet-writer-nvcomp-snappy
devavret Aug 19, 2021
f90409c
Merge branch 'branch-21.10' into parquet-writer-nvcomp-snappy
devavret Aug 19, 2021
40ebd1e
Update parquet writer to use nvcomp 2.1
devavret Aug 24, 2021
4a2cb24
One more cmake change related to updating nvcomp
devavret Aug 24, 2021
6019b0f
Update nvcomp to version with fix for snappy decompressor
devavret Aug 31, 2021
140d3d0
Fix allocation size bug
devavret Sep 2, 2021
05f5343
Merge branch 'branch-21.10' into parquet-writer-nvcomp-snappy
devavret Sep 3, 2021
62d92b4
Update cmake to find nvcomp in new manner
devavret Sep 3, 2021
3c73be3
Make nvcomp private in cmake and update get_nvcomp
devavret Sep 7, 2021
e0a013d
Add an env var flip switch to choose b/w nvcomp and inbuilt compressor
devavret Sep 8, 2021
7501b11
Merge branch 'branch-21.10' into parquet-writer-nvcomp-snappy
devavret Sep 8, 2021
bfa1366
Static linking nvcomp into libcudf
devavret Sep 8, 2021
203cf15
Review changes
devavret Sep 9, 2021
6721fb8
Merge changes from nvcomp -fPIC
devavret Sep 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ rapids_find_package(Threads REQUIRED
rapids_cpm_init()
# find jitify
include(cmake/thirdparty/get_jitify.cmake)
# find nvCOMP
include(cmake/thirdparty/get_nvcomp.cmake)
# find thrust/cub
include(cmake/thirdparty/get_thrust.cmake)
# find rmm
Expand Down Expand Up @@ -503,13 +505,16 @@ target_compile_definitions(cudf PUBLIC "SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${RMM_L
# Compile stringified JIT sources first
add_dependencies(cudf jitify_preprocess_run)

set_target_properties(nvcomp PROPERTIES POSITION_INDEPENDENT_CODE ON)
devavret marked this conversation as resolved.
Show resolved Hide resolved

# Specify the target module library dependencies
target_link_libraries(cudf
PUBLIC ZLIB::ZLIB
${ARROW_LIBRARIES}
cudf::Thrust
rmm::rmm
PRIVATE cuco::cuco)
PRIVATE cuco::cuco
nvcomp::nvcomp)
kkraus14 marked this conversation as resolved.
Show resolved Hide resolved

# Add Conda library, and include paths if specified
if(TARGET conda_env)
Expand Down
40 changes: 40 additions & 0 deletions cpp/cmake/thirdparty/get_nvcomp.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#=============================================================================
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#=============================================================================

function(find_and_configure_nvcomp VERSION)

# Find or install nvcomp
rapids_cpm_find(nvcomp ${VERSION}
GLOBAL_TARGETS nvcomp::nvcomp
CPM_ARGS
GITHUB_REPOSITORY NVIDIA/nvcomp
GIT_TAG 3a12516afdeab4ace01298031757f84b8dda81b7
# GIT_SHALLOW TRUE
OPTIONS "BUILD_STATIC ON"
"BUILD_TESTS OFF"
"BUILD_BENCHMARKS OFF"
"BUILD_EXAMPLES OFF"
)

if(NOT TARGET nvcomp::nvcomp)
add_library(nvcomp::nvcomp ALIAS nvcomp)
endif()

endfunction()

set(CUDF_MIN_VERSION_nvCOMP 2.1.0)

find_and_configure_nvcomp(${CUDF_MIN_VERSION_nvCOMP})
2 changes: 1 addition & 1 deletion cpp/src/io/comp/snap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ static __device__ uint32_t Match60(const uint8_t* src1,
* @param[out] outputs Compression status per block
* @param[in] count Number of blocks to compress
*/
extern "C" __global__ void __launch_bounds__(128)
__global__ void __launch_bounds__(128)
snap_kernel(gpu_inflate_input_s* inputs, gpu_inflate_status_s* outputs, int count)
{
__shared__ __align__(16) snap_state_s state_g;
Expand Down
18 changes: 14 additions & 4 deletions cpp/src/io/parquet/page_enc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ __global__ void __launch_bounds__(128)
device_span<parquet_column_device_view const> col_desc,
statistics_merge_group* page_grstats,
statistics_merge_group* chunk_grstats,
size_t max_page_comp_data_size,
int32_t num_columns)
{
// TODO: All writing seems to be done by thread 0. Could be replaced by thrust foreach
Expand Down Expand Up @@ -270,6 +271,8 @@ __global__ void __launch_bounds__(128)
uint32_t page_offset = ck_g.ck_stat_size;
uint32_t num_dict_entries = 0;
uint32_t comp_page_offset = ck_g.ck_stat_size;
uint32_t page_headers_size = 0;
uint32_t max_page_data_size = 0;
uint32_t cur_row = ck_g.start_row;
uint32_t ck_max_stats_len = 0;
uint32_t max_stats_len = 0;
Expand All @@ -295,7 +298,9 @@ __global__ void __launch_bounds__(128)
page_g.num_leaf_values = ck_g.num_dict_entries;
page_g.num_values = ck_g.num_dict_entries; // TODO: shouldn't matter for dict page
page_offset += page_g.max_hdr_size + page_g.max_data_size;
comp_page_offset += page_g.max_hdr_size + GetMaxCompressedBfrSize(page_g.max_data_size);
comp_page_offset += page_g.max_hdr_size + max_page_comp_data_size;
page_headers_size += page_g.max_hdr_size;
max_page_data_size = max(max_page_data_size, page_g.max_data_size);
}
__syncwarp();
if (t == 0) {
Expand Down Expand Up @@ -378,7 +383,9 @@ __global__ void __launch_bounds__(128)
pagestats_g.start_chunk = ck_g.first_fragment + page_start;
pagestats_g.num_chunks = page_g.num_fragments;
page_offset += page_g.max_hdr_size + page_g.max_data_size;
comp_page_offset += page_g.max_hdr_size + GetMaxCompressedBfrSize(page_g.max_data_size);
comp_page_offset += page_g.max_hdr_size + max_page_comp_data_size;
page_headers_size += page_g.max_hdr_size;
max_page_data_size = max(max_page_data_size, page_g.max_data_size);
cur_row += rows_in_page;
ck_max_stats_len = max(ck_max_stats_len, max_stats_len);
}
Expand Down Expand Up @@ -416,7 +423,8 @@ __global__ void __launch_bounds__(128)
}
ck_g.num_pages = num_pages;
ck_g.bfr_size = page_offset;
ck_g.compressed_size = comp_page_offset;
ck_g.page_headers_size = page_headers_size;
ck_g.max_page_data_size = max_page_data_size;
pagestats_g.start_chunk = ck_g.first_page + ck_g.use_dictionary; // Exclude dictionary
pagestats_g.num_chunks = num_pages - ck_g.use_dictionary;
}
Expand Down Expand Up @@ -1973,6 +1981,7 @@ void InitFragmentStatistics(device_2dspan<statistics_group> groups,
* @param[in] num_columns Number of columns
* @param[out] page_grstats Setup for page-level stats
* @param[out] chunk_grstats Setup for chunk-level stats
* @param[in] max_page_comp_data_size Calculated maximum compressed data size of pages
* @param[in] stream CUDA stream to use, default 0
*/
void InitEncoderPages(device_2dspan<EncColumnChunk> chunks,
Expand All @@ -1981,12 +1990,13 @@ void InitEncoderPages(device_2dspan<EncColumnChunk> chunks,
int32_t num_columns,
statistics_merge_group* page_grstats,
statistics_merge_group* chunk_grstats,
size_t max_page_comp_data_size,
rmm::cuda_stream_view stream)
{
auto num_rowgroups = chunks.size().first;
dim3 dim_grid(num_columns, num_rowgroups); // 1 threadblock per rowgroup
gpuInitPages<<<dim_grid, 128, 0, stream.value()>>>(
chunks, pages, col_desc, page_grstats, chunk_grstats, num_columns);
chunks, pages, col_desc, page_grstats, chunk_grstats, max_page_comp_data_size, num_columns);
}

/**
Expand Down
27 changes: 15 additions & 12 deletions cpp/src/io/parquet/parquet_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ struct EncColumnChunk {
statistics_chunk const* stats; //!< Fragment statistics
uint32_t bfr_size; //!< Uncompressed buffer size
uint32_t compressed_size; //!< Compressed buffer size
uint32_t max_page_data_size; //!< Max data size (excuding header) of any page in this chunk
uint32_t page_headers_size; //!< Sum of size of all page headers
uint32_t start_row; //!< First row of chunk
uint32_t num_rows; //!< Number of rows in chunk
size_type num_values; //!< Number of values in chunk. Different from num_rows for nested types
Expand Down Expand Up @@ -538,15 +540,17 @@ void get_dictionary_indices(cudf::detail::device_2dspan<EncColumnChunk> chunks,
* @param[in] num_columns Number of columns
* @param[in] page_grstats Setup for page-level stats
* @param[in] chunk_grstats Setup for chunk-level stats
* @param[in] max_page_comp_data_size Calculated maximum compressed data size of pages
* @param[in] stream CUDA stream to use, default 0
*/
void InitEncoderPages(cudf::detail::device_2dspan<EncColumnChunk> chunks,
device_span<gpu::EncPage> pages,
device_span<parquet_column_device_view const> col_desc,
int32_t num_columns,
statistics_merge_group* page_grstats = nullptr,
statistics_merge_group* chunk_grstats = nullptr,
rmm::cuda_stream_view stream = rmm::cuda_stream_default);
statistics_merge_group* page_grstats,
statistics_merge_group* chunk_grstats,
size_t max_page_comp_data_size,
rmm::cuda_stream_view stream);

/**
* @brief Launches kernel for packing column data into parquet pages
Expand All @@ -557,18 +561,17 @@ void InitEncoderPages(cudf::detail::device_2dspan<EncColumnChunk> chunks,
* @param[in] stream CUDA stream to use, default 0
*/
void EncodePages(device_span<EncPage> pages,
device_span<gpu_inflate_input_s> comp_in = {},
device_span<gpu_inflate_status_s> comp_out = {},
rmm::cuda_stream_view stream = rmm::cuda_stream_default);
device_span<gpu_inflate_input_s> comp_in,
device_span<gpu_inflate_status_s> comp_out,
rmm::cuda_stream_view stream);

/**
* @brief Launches kernel to make the compressed vs uncompressed chunk-level decision
*
* @param[in,out] chunks Column chunks (updated with actual compressed/uncompressed sizes)
* @param[in] stream CUDA stream to use, default 0
*/
void DecideCompression(device_span<EncColumnChunk> chunks,
rmm::cuda_stream_view stream = rmm::cuda_stream_default);
void DecideCompression(device_span<EncColumnChunk> chunks, rmm::cuda_stream_view stream);

/**
* @brief Launches kernel to encode page headers
Expand All @@ -580,10 +583,10 @@ void DecideCompression(device_span<EncColumnChunk> chunks,
* @param[in] stream CUDA stream to use, default 0
*/
void EncodePageHeaders(device_span<EncPage> pages,
device_span<gpu_inflate_status_s const> comp_out = {},
device_span<statistics_chunk const> page_stats = {},
const statistics_chunk* chunk_stats = nullptr,
rmm::cuda_stream_view stream = rmm::cuda_stream_default);
device_span<gpu_inflate_status_s const> comp_out,
device_span<statistics_chunk const> page_stats,
const statistics_chunk* chunk_stats,
rmm::cuda_stream_view stream);

/**
* @brief Launches kernel to gather pages to a single contiguous block per chunk
Expand Down
Loading