diff --git a/include/cuco/detail/static_map/kernels.cuh b/include/cuco/detail/static_map/kernels.cuh index a58e9e273..cbf1cef06 100644 --- a/include/cuco/detail/static_map/kernels.cuh +++ b/include/cuco/detail/static_map/kernels.cuh @@ -122,6 +122,7 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem( Ref ref, typename SharedMapRefType::extent_type window_extent) { + static_assert(CGSize == 1, "use shared_memory kernel only if cg_size == 1"); namespace cg = cooperative_groups; using Key = typename Ref::key_type; using Value = typename Ref::mapped_type; @@ -156,52 +157,42 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem( block.sync(); while ((idx - thread_idx / CGSize) < n) { - if constexpr (CGSize == 1) { - int32_t inserted = 0; - int32_t local_cardinality = 0; - // insert-or-apply into the shared map first - if (idx < n) { - value_type const& insert_pair = *(first + idx); - inserted = shared_map_ref.insert_or_apply(insert_pair, op); - } - if (idx - warp_thread_idx < n) { // all threads in warp particpate - local_cardinality = cg::reduce(warp, inserted, cg::plus()); - } - if (warp_thread_idx == 0) { - block_cardinality.fetch_add(local_cardinality, cuda::memory_order_relaxed); - } - block.sync(); - if (block_cardinality > BlockSize) { break; } - } else { - auto const tile = cg::tiled_partition(block); - if (idx < n) { - value_type const& insert_pair = *(first + idx); - ref.insert_or_apply(tile, insert_pair, op); - } + int32_t inserted = 0; + int32_t local_cardinality = 0; + // insert-or-apply into the shared map first + if (idx < n) { + value_type const& insert_pair = *(first + idx); + inserted = shared_map_ref.insert_or_apply(insert_pair, op); + } + if (idx - warp_thread_idx < n) { // all threads in warp particpate + local_cardinality = cg::reduce(warp, inserted, cg::plus()); + } + if (warp_thread_idx == 0) { + block_cardinality.fetch_add(local_cardinality, cuda::memory_order_relaxed); } + block.sync(); + if (block_cardinality > BlockSize) { break; } idx += loop_stride; } - if constexpr (CGSize == 1) { - // insert-or-apply from shared map to global map - auto window_idx = thread_idx; - while (window_idx < num_windows) { - auto const slot = storage[window_idx][0]; - if (not cuco::detail::bitwise_compare(slot.first, ref.empty_key_sentinel())) { - ref.insert_or_apply(slot, op); - } - window_idx += BlockSize; + // insert-or-apply from shared map to global map + auto window_idx = thread_idx; + while (window_idx < num_windows) { + auto const slot = storage[window_idx][0]; + if (not cuco::detail::bitwise_compare(slot.first, ref.empty_key_sentinel())) { + ref.insert_or_apply(slot, op); } + window_idx += BlockSize; + } - // insert-or-apply into global map for the remaining elements whose block_cardinality - // exceeds the cardinality threshold. - if (block_cardinality > BlockSize) { + // insert-or-apply into global map for the remaining elements whose block_cardinality + // exceeds the cardinality threshold. + if (block_cardinality > BlockSize) { + idx += loop_stride; + while (idx < n) { + value_type const& insert_pair = *(first + idx); + ref.insert_or_apply(insert_pair, op); idx += loop_stride; - while (idx < n) { - value_type const& insert_pair = *(first + idx); - ref.insert_or_apply(insert_pair, op); - idx += loop_stride; - } } } } diff --git a/include/cuco/detail/static_map/static_map.inl b/include/cuco/detail/static_map/static_map.inl index ee2db7342..8f7b58ab1 100644 --- a/include/cuco/detail/static_map/static_map.inl +++ b/include/cuco/detail/static_map/static_map.inl @@ -321,43 +321,52 @@ void static_map((1.0 / load_factor) * shared_map_num_elements); - - using extent_type = cuco::extent; - using shared_map_type = cuco::static_map>; - using shared_map_ref_type = typename shared_map_type::ref_type<>; - auto constexpr window_extent = cuco::make_window_extent(extent_type{}); - - using ref_type = decltype(ref(op::insert_or_apply)); - - auto insert_or_apply_shmem_fn_ptr = static_map_ns::detail:: - insert_or_apply_shmem; - - int32_t const max_op_grid_size = - cuco::detail::max_occupancy_grid_size(shmem_block_size, insert_or_apply_shmem_fn_ptr); - - auto const shmem_grid_size = std::min(default_grid_size, max_op_grid_size); - auto const num_loops_per_thread = num / (shmem_grid_size * shmem_block_size); - - // use shared_memory only if each thread has atleast 2 elements to process - if (num_loops_per_thread > 2) { - static_map_ns::detail::insert_or_apply_shmem - <<>>( - first, num, op, ref(op::insert_or_apply), window_extent); + int32_t const default_grid_size = cuco::detail::grid_size(num, cg_size); + + if constexpr (cg_size == 1) { + int32_t constexpr shmem_block_size = 1024; + shmem_size_type constexpr cardinality_threshold = shmem_block_size; + shmem_size_type constexpr shared_map_num_elements = cardinality_threshold + shmem_block_size; + float constexpr load_factor = 0.7; + shmem_size_type constexpr shared_map_size = + static_cast((1.0 / load_factor) * shared_map_num_elements); + + using extent_type = cuco::extent; + using shared_map_type = cuco::static_map>; + using shared_map_ref_type = typename shared_map_type::ref_type<>; + auto constexpr window_extent = cuco::make_window_extent(extent_type{}); + + using ref_type = decltype(ref(op::insert_or_apply)); + + auto insert_or_apply_shmem_fn_ptr = static_map_ns::detail:: + insert_or_apply_shmem; + + int32_t const max_op_grid_size = + cuco::detail::max_occupancy_grid_size(shmem_block_size, insert_or_apply_shmem_fn_ptr); + + int32_t const shmem_default_grid_size = + cuco::detail::grid_size(num, cg_size, cuco::detail::default_stride(), shmem_block_size); + + auto const shmem_grid_size = std::min(shmem_default_grid_size, max_op_grid_size); + auto const num_elements_per_thread = num / (shmem_grid_size * shmem_block_size); + + // use shared_memory only if each thread has atleast 3 elements to process + if (num_elements_per_thread > 2) { + static_map_ns::detail::insert_or_apply_shmem + <<>>( + first, num, op, ref(op::insert_or_apply), window_extent); + } else { + static_map_ns::detail::insert_or_apply + <<>>( + first, num, op, ref(op::insert_or_apply)); + } } else { static_map_ns::detail::insert_or_apply <<>>(