From abd247483e090d4e354d3dafcdac9263e8008a99 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Fri, 27 May 2022 16:27:59 -0400 Subject: [PATCH] Improve static_map::retrieve_all: use allocator to handle temporary memory allocation --- include/cuco/detail/static_map.inl | 41 +++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/include/cuco/detail/static_map.inl b/include/cuco/detail/static_map.inl index 3a8fea931..23d797cae 100644 --- a/include/cuco/detail/static_map.inl +++ b/include/cuco/detail/static_map.inl @@ -18,12 +18,12 @@ #include #include -#include -#include #include #include #include +#include + namespace cuco { template @@ -220,10 +220,39 @@ std::pair static_map::retrieve_a auto filled = detail::slot_is_filled{get_empty_key_sentinel()}; auto zipped_out_begin = thrust::make_zip_iterator(thrust::make_tuple(keys_out, values_out)); - auto const zipped_out_end = - thrust::copy_if(thrust::cuda::par.on(stream), begin, end, zipped_out_begin, filled); - auto const num = std::distance(zipped_out_begin, zipped_out_end); - return std::make_pair(keys_out + num, values_out + num); + std::size_t temp_storage_bytes = 0; + using temp_allocator_type = typename std::allocator_traits::rebind_alloc; + auto temp_allocator = temp_allocator_type{slot_allocator_}; + auto d_num_out = reinterpret_cast( + std::allocator_traits::allocate(temp_allocator, sizeof(std::size_t))); + cub::DeviceSelect::If(nullptr, + temp_storage_bytes, + begin, + zipped_out_begin, + d_num_out, + get_capacity(), + filled, + stream); + + // Allocate temporary storage + auto d_temp_storage = + std::allocator_traits::allocate(temp_allocator, temp_storage_bytes); + + cub::DeviceSelect::If(d_temp_storage, + temp_storage_bytes, + begin, + zipped_out_begin, + d_num_out, + get_capacity(), + filled, + stream); + + std::size_t h_num_out; + CUCO_CUDA_TRY( + cudaMemcpyAsync(&h_num_out, d_num_out, sizeof(std::size_t), cudaMemcpyDeviceToHost, stream)); + CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); + + return std::make_pair(keys_out + h_num_out, values_out + h_num_out); } template