Commit 015f945: minor improvements

use shared_memory kernel only if `cg_size == 1`.
use `shmem_block_size` when calculating `shmem_grid_size`.

srinivasyadav18 committed Aug 2, 2024 (1 parent: 2cd20f9)

Showing 2 changed files with 76 additions and 76 deletions.
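As a rough illustration of the second point in the commit message (sizing the shared-memory grid from shmem_block_size rather than from the default launch configuration), the standalone C++ sketch below uses plain ceiling division as a stand-in for cuco::detail::grid_size (which also accounts for a stride); the input size and the 128-thread default block are assumptions, not values taken from this commit.

#include <cstdint>
#include <cstdio>

int main()
{
  // Hypothetical sizes; ceiling division stands in for cuco::detail::grid_size.
  std::int64_t const num              = 1 << 22;  // assumed number of input pairs
  std::int32_t const default_block    = 128;      // assumed default block size
  std::int32_t const shmem_block_size = 1024;     // block size of the shared-memory kernel

  auto const grid_for = [&](std::int32_t block_size) {
    return static_cast<std::int32_t>((num + block_size - 1) / block_size);
  };

  std::int32_t const default_grid = grid_for(default_block);     // sized for 128-thread blocks
  std::int32_t const shmem_grid   = grid_for(shmem_block_size);  // sized for 1024-thread blocks

  // The per-thread work estimate that gates the shared-memory path is only
  // meaningful when the grid and the block size describe the same launch.
  std::printf("elements per thread, grid sized for the default block: %lld\n",
              static_cast<long long>(num / (std::int64_t{default_grid} * shmem_block_size)));
  std::printf("elements per thread, grid sized for shmem_block_size:  %lld\n",
              static_cast<long long>(num / (std::int64_t{shmem_grid} * shmem_block_size)));
  return 0;
}

With these assumed numbers the first estimate is 0 and the second is 1, which is the kind of mismatch the commit avoids by introducing shmem_default_grid_size.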
include/cuco/detail/static_map/kernels.cuh (30 additions, 39 deletions)

@@ -122,6 +122,7 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem(
Ref ref,
typename SharedMapRefType::extent_type window_extent)
{
static_assert(CGSize == 1, "use shared_memory kernel only if cg_size == 1");
namespace cg = cooperative_groups;
using Key = typename Ref::key_type;
using Value = typename Ref::mapped_type;
@@ -156,52 +157,42 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem(
block.sync();

while ((idx - thread_idx / CGSize) < n) {
if constexpr (CGSize == 1) {
int32_t inserted = 0;
int32_t local_cardinality = 0;
// insert-or-apply into the shared map first
if (idx < n) {
value_type const& insert_pair = *(first + idx);
inserted = shared_map_ref.insert_or_apply(insert_pair, op);
}
if (idx - warp_thread_idx < n) { // all threads in warp participate
local_cardinality = cg::reduce(warp, inserted, cg::plus<int32_t>());
}
if (warp_thread_idx == 0) {
block_cardinality.fetch_add(local_cardinality, cuda::memory_order_relaxed);
}
block.sync();
if (block_cardinality > BlockSize) { break; }
} else {
auto const tile = cg::tiled_partition<CGSize>(block);
if (idx < n) {
value_type const& insert_pair = *(first + idx);
ref.insert_or_apply(tile, insert_pair, op);
}
int32_t inserted = 0;
int32_t local_cardinality = 0;
// insert-or-apply into the shared map first
if (idx < n) {
value_type const& insert_pair = *(first + idx);
inserted = shared_map_ref.insert_or_apply(insert_pair, op);
}
if (idx - warp_thread_idx < n) { // all threads in warp participate
local_cardinality = cg::reduce(warp, inserted, cg::plus<int32_t>());
}
if (warp_thread_idx == 0) {
block_cardinality.fetch_add(local_cardinality, cuda::memory_order_relaxed);
}
block.sync();
if (block_cardinality > BlockSize) { break; }
idx += loop_stride;
}

if constexpr (CGSize == 1) {
// insert-or-apply from shared map to global map
auto window_idx = thread_idx;
while (window_idx < num_windows) {
auto const slot = storage[window_idx][0];
if (not cuco::detail::bitwise_compare(slot.first, ref.empty_key_sentinel())) {
ref.insert_or_apply(slot, op);
}
window_idx += BlockSize;
// insert-or-apply from shared map to global map
auto window_idx = thread_idx;
while (window_idx < num_windows) {
auto const slot = storage[window_idx][0];
if (not cuco::detail::bitwise_compare(slot.first, ref.empty_key_sentinel())) {
ref.insert_or_apply(slot, op);
}
window_idx += BlockSize;
}

// insert-or-apply into global map for the remaining elements whose block_cardinality
// exceeds the cardinality threshold.
if (block_cardinality > BlockSize) {
// insert-or-apply into global map for the remaining elements whose block_cardinality
// exceeds the cardinality threshold.
if (block_cardinality > BlockSize) {
idx += loop_stride;
while (idx < n) {
value_type const& insert_pair = *(first + idx);
ref.insert_or_apply(insert_pair, op);
idx += loop_stride;
while (idx < n) {
value_type const& insert_pair = *(first + idx);
ref.insert_or_apply(insert_pair, op);
idx += loop_stride;
}
}
}
}
(remaining unchanged lines not shown)
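The loop in this kernel tracks how many new keys a block has inserted: each thread records whether its insert_or_apply created a new entry, a warp-level cg::reduce sums those flags, and one lane per warp adds the sum to a block-scoped atomic; once block_cardinality exceeds BlockSize, the block stops filling its shared-memory map and falls back to the global map. The following is a minimal, standalone CUDA sketch of just that counting pattern, not cuco's kernel; the name count_new_keys, the did_insert input array, and the launch shape are hypothetical.

// Standalone sketch of the warp-reduce / block-cardinality pattern (hypothetical names).
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda/atomic>
#include <cstdint>
#include <new>

namespace cg = cooperative_groups;

template <int32_t BlockSize>
__global__ void count_new_keys(int32_t const* did_insert, int32_t n, int32_t* block_counts)
{
  auto const block = cg::this_thread_block();
  auto const warp  = cg::tiled_partition<32>(block);

  // Block-scoped atomic counter in shared memory, constructed by one thread.
  __shared__ cuda::atomic<int32_t, cuda::thread_scope_block> block_cardinality;
  if (block.thread_rank() == 0) {
    new (&block_cardinality) cuda::atomic<int32_t, cuda::thread_scope_block>{0};
  }
  block.sync();

  auto const idx = static_cast<int32_t>(BlockSize * blockIdx.x + block.thread_rank());

  // 1 if this thread inserted a new key, 0 otherwise (out-of-range threads contribute 0,
  // but still take part in the warp-wide reduction).
  int32_t const inserted          = (idx < n) ? did_insert[idx] : 0;
  int32_t const local_cardinality = cg::reduce(warp, inserted, cg::plus<int32_t>());

  if (warp.thread_rank() == 0) {
    block_cardinality.fetch_add(local_cardinality, cuda::memory_order_relaxed);
  }
  block.sync();

  // In the real kernel, block_cardinality > BlockSize triggers the fallback to the
  // global map; here we just report the per-block count.
  if (block.thread_rank() == 0) {
    block_counts[blockIdx.x] = block_cardinality.load(cuda::memory_order_relaxed);
  }
}

A launch would look like count_new_keys<1024><<<num_blocks, 1024, 0, stream>>>(flags, n, counts).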
include/cuco/detail/static_map/static_map.inl (46 additions, 37 deletions)

@@ -321,43 +321,52 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora

using shmem_size_type = int32_t;

int32_t constexpr shmem_block_size = 1024;
int32_t const default_grid_size = cuco::detail::grid_size(num, cg_size);

shmem_size_type constexpr cardinality_threshold = shmem_block_size;
shmem_size_type constexpr shared_map_num_elements = cardinality_threshold + shmem_block_size;
float constexpr load_factor = 0.7;
shmem_size_type constexpr shared_map_size =
static_cast<shmem_size_type>((1.0 / load_factor) * shared_map_num_elements);

using extent_type = cuco::extent<shmem_size_type, shared_map_size>;
using shared_map_type = cuco::static_map<Key,
T,
extent_type,
cuda::thread_scope_block,
KeyEqual,
ProbingScheme,
Allocator,
cuco::storage<1>>;
using shared_map_ref_type = typename shared_map_type::ref_type<>;
auto constexpr window_extent = cuco::make_window_extent<shared_map_ref_type>(extent_type{});

using ref_type = decltype(ref(op::insert_or_apply));

auto insert_or_apply_shmem_fn_ptr = static_map_ns::detail::
insert_or_apply_shmem<cg_size, shmem_block_size, shared_map_ref_type, InputIt, Op, ref_type>;

int32_t const max_op_grid_size =
cuco::detail::max_occupancy_grid_size(shmem_block_size, insert_or_apply_shmem_fn_ptr);

auto const shmem_grid_size = std::min(default_grid_size, max_op_grid_size);
auto const num_loops_per_thread = num / (shmem_grid_size * shmem_block_size);

// use shared_memory only if each thread has at least 2 elements to process
if (num_loops_per_thread > 2) {
static_map_ns::detail::insert_or_apply_shmem<cg_size, shmem_block_size, shared_map_ref_type>
<<<shmem_grid_size, shmem_block_size, 0, stream.get()>>>(
first, num, op, ref(op::insert_or_apply), window_extent);
int32_t const default_grid_size = cuco::detail::grid_size(num, cg_size);

if constexpr (cg_size == 1) {
int32_t constexpr shmem_block_size = 1024;
shmem_size_type constexpr cardinality_threshold = shmem_block_size;
shmem_size_type constexpr shared_map_num_elements = cardinality_threshold + shmem_block_size;
float constexpr load_factor = 0.7;
shmem_size_type constexpr shared_map_size =
static_cast<shmem_size_type>((1.0 / load_factor) * shared_map_num_elements);

using extent_type = cuco::extent<shmem_size_type, shared_map_size>;
using shared_map_type = cuco::static_map<Key,
T,
extent_type,
cuda::thread_scope_block,
KeyEqual,
ProbingScheme,
Allocator,
cuco::storage<1>>;
using shared_map_ref_type = typename shared_map_type::ref_type<>;
auto constexpr window_extent = cuco::make_window_extent<shared_map_ref_type>(extent_type{});

using ref_type = decltype(ref(op::insert_or_apply));

auto insert_or_apply_shmem_fn_ptr = static_map_ns::detail::
insert_or_apply_shmem<cg_size, shmem_block_size, shared_map_ref_type, InputIt, Op, ref_type>;

int32_t const max_op_grid_size =
cuco::detail::max_occupancy_grid_size(shmem_block_size, insert_or_apply_shmem_fn_ptr);

int32_t const shmem_default_grid_size =
cuco::detail::grid_size(num, cg_size, cuco::detail::default_stride(), shmem_block_size);

auto const shmem_grid_size = std::min(shmem_default_grid_size, max_op_grid_size);
auto const num_elements_per_thread = num / (shmem_grid_size * shmem_block_size);

// use shared_memory only if each thread has at least 3 elements to process
if (num_elements_per_thread > 2) {
static_map_ns::detail::insert_or_apply_shmem<cg_size, shmem_block_size, shared_map_ref_type>
<<<shmem_grid_size, shmem_block_size, 0, stream.get()>>>(
first, num, op, ref(op::insert_or_apply), window_extent);
} else {
static_map_ns::detail::insert_or_apply<cg_size, cuco::detail::default_block_size()>
<<<default_grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
first, num, op, ref(op::insert_or_apply));
}
} else {
static_map_ns::detail::insert_or_apply<cg_size, cuco::detail::default_block_size()>
<<<default_grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
(remaining unchanged lines not shown)
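Taken together, the host-side changes above boil down to a small dispatch rule: only consider the shared-memory kernel when cg_size == 1, size its block-local map for a 1024-thread block at a 0.7 load factor, derive the grid from shmem_block_size and cap it at the kernel's occupancy limit, and take the shared-memory path only when each thread has more than two (i.e., at least three) elements to process. The condensed C++ sketch below restates that rule; the function and parameter names are hypothetical, and plain ceiling division again stands in for cuco::detail::grid_size.

#include <algorithm>
#include <cstdint>

enum class Path { shared_memory_kernel, global_kernel };

// Hypothetical condensation of the dispatch logic in insert_or_apply.
Path choose_insert_or_apply_path(std::int64_t num,
                                 std::int32_t cg_size,
                                 std::int32_t max_occupancy_grid_size)
{
  // The shared-memory kernel now static_asserts CGSize == 1, so other
  // cooperative-group sizes always go to the global-map kernel.
  if (cg_size != 1) { return Path::global_kernel; }

  std::int32_t constexpr shmem_block_size        = 1024;
  std::int32_t constexpr cardinality_threshold   = shmem_block_size;
  std::int32_t constexpr shared_map_num_elements = cardinality_threshold + shmem_block_size;
  // 0.7 load factor over 2048 elements: about 2925 slots before cuco rounds this
  // up to a valid window extent.
  [[maybe_unused]] std::int32_t constexpr shared_map_size =
    static_cast<std::int32_t>((1.0 / 0.7) * shared_map_num_elements);

  // Grid derived from the shared-memory kernel's own block size, capped by occupancy.
  std::int32_t const shmem_default_grid_size =
    static_cast<std::int32_t>((num + shmem_block_size - 1) / shmem_block_size);
  std::int32_t const shmem_grid_size = std::min(shmem_default_grid_size, max_occupancy_grid_size);

  // Use shared memory only if each thread has at least three elements to process.
  auto const num_elements_per_thread = num / (std::int64_t{shmem_grid_size} * shmem_block_size);
  return (num_elements_per_thread > 2) ? Path::shared_memory_kernel : Path::global_kernel;
}

The > 2 comparison is what the updated comment ("at least 3 elements to process") now matches.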
