Skip to content

Commit

Permalink
Add host-bulk for_each for static_map (#565)
Browse files Browse the repository at this point in the history
  • Loading branch information
srinivasyadav18 authored Aug 13, 2024
1 parent 99282c0 commit 118fd1f
Show file tree
Hide file tree
Showing 9 changed files with 455 additions and 31 deletions.
40 changes: 40 additions & 0 deletions include/cuco/detail/open_addressing/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,46 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void erase(InputIt first,
}
}

/**
* @brief For each key in the range [first, first + n), applies the function object `callback_op` to
* the copy of all corresponding matches found in the container.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `key_type` of the data structure
* @tparam CallbackOp Type of unary callback function object
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param first Beginning of the sequence of input elements
* @param n Number of input elements
* @param callback_op Function to call on every matched slot found in the container
* @param ref Non-owning container device ref used to access the slot storage
*/
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename CallbackOp, typename Ref>
CUCO_KERNEL __launch_bounds__(BlockSize) void for_each_n(InputIt first,
cuco::detail::index_type n,
CallbackOp callback_op,
Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
typename std::iterator_traits<InputIt>::value_type const& key{*(first + idx)};
if constexpr (CGSize == 1) {
ref.for_each(key, callback_op);
} else {
auto const tile =
cooperative_groups::tiled_partition<CGSize>(cooperative_groups::this_thread_block());
ref.for_each(tile, key, callback_op);
}
idx += loop_stride;
}
}

/**
* @brief Indicates whether the keys in the range `[first, first + n)` are contained in the data
* structure if `pred` of the corresponding stencil returns true.
Expand Down
62 changes: 62 additions & 0 deletions include/cuco/detail/open_addressing/open_addressing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cuco/storage.cuh>
#include <cuco/utility/traits.hpp>

#include <cub/device/device_for.cuh>
#include <cub/device/device_select.cuh>
#include <cuda/atomic>
#include <thrust/iterator/constant_iterator.h>
Expand Down Expand Up @@ -681,6 +682,67 @@ class open_addressing_impl {
return output_begin + h_num_out;
}

/**
* @brief Asynchronously applies the given function object `callback_op` to the copy of every
* filled slot in the container
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam CallbackOp Type of unary callback function object
*
* @param callback_op Function to call on every filled slot in the container
* @param stream CUDA stream used for this operation
*/
template <typename CallbackOp>
void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const
{
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
this->empty_key_sentinel(), this->erased_key_sentinel()};

auto storage_ref = this->storage_ref();
auto const op = [callback_op, is_filled, storage_ref] __device__(auto const window_slots) {
for (auto const slot : window_slots) {
if (is_filled(slot)) { callback_op(slot); }
}
};

CUCO_CUDA_TRY(cub::DeviceFor::ForEachCopyN(
storage_ref.data(), storage_ref.num_windows(), op, stream.get()));
}

/**
* @brief For each key in the range [first, last), asynchronously applies the function object
* `callback_op` to the copy of all corresponding matches found in the container.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam InputIt Device accessible random access input iterator
* @tparam CallbackOp Type of unary callback function object
* @tparam Ref Type of non-owning device container ref allowing access to storage
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param callback_op Function to call on every match found in the container
* @param container_ref Non-owning device container ref used to access the slot storage
* @param stream CUDA stream used for this operation
*/
template <typename InputIt, typename CallbackOp, typename Ref>
void for_each_async(InputIt first,
InputIt last,
CallbackOp&& callback_op,
Ref container_ref,
cuda::stream_ref stream) const noexcept
{
auto const num_keys = cuco::detail::distance(first, last);
if (num_keys == 0) { return; }

auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);

detail::for_each_n<cg_size, cuco::detail::default_block_size()>
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
first, num_keys, std::forward<CallbackOp>(callback_op), container_ref);
}

/**
* @brief Gets the number of elements in the container
*
Expand Down
47 changes: 22 additions & 25 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -966,17 +966,16 @@ class open_addressing_ref_impl {
}

/**
* @brief Executes a callback on every element in the container with key equivalent to the probe
* key.
* @brief For a given key, applies the function object `callback_op` to the copy of all
* corresponding matches found in the container.
*
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
* `key` to the callback.
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Unary callback functor or device lambda
* @tparam CallbackOp Type of unary callback function object
*
* @param key The key to search for
* @param callback_op Function to call on every element found
* @param callback_op Function to apply to every matched slot
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
Expand All @@ -995,7 +994,7 @@ class open_addressing_ref_impl {
return;
}
case detail::equal_result::EQUAL: {
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
callback_op(window_slots[i]);
continue;
}
default: continue;
Expand All @@ -1006,24 +1005,23 @@ class open_addressing_ref_impl {
}

/**
* @brief Executes a callback on every element in the container with key equivalent to the probe
* key.
*
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
* `key` to the callback.
* @brief For a given key, applies the function object `callback_op` to the copy of all
* corresponding matches found in the container.
*
* @note This function uses cooperative group semantics, meaning that any thread may call the
* callback if it finds a matching element. If multiple elements are found within the same group,
* each thread with a match will call the callback with its associated element.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @note Synchronizing `group` within `callback_op` is undefined behavior.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Unary callback functor or device lambda
* @tparam CallbackOp Type of unary callback function object
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
* @param callback_op Function to call on every element found
* @param callback_op Function to apply to every matched slot
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
Expand All @@ -1045,7 +1043,7 @@ class open_addressing_ref_impl {
continue;
}
case detail::equal_result::EQUAL: {
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
callback_op(window_slots[i]);
continue;
}
default: {
Expand All @@ -1060,31 +1058,30 @@ class open_addressing_ref_impl {
}

/**
* @brief Executes a callback on every element in the container with key equivalent to the probe
* key and can additionally perform work that requires synchronizing the Cooperative Group
* performing this operation.
*
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
* `key` to the callback.
* @brief Applies the function object `callback_op` to the copy of every slot in the container
* with key equivalent to the probe key and can additionally perform work that requires
* synchronizing the Cooperative Group performing this operation.
*
* @note This function uses cooperative group semantics, meaning that any thread may call the
* callback if it finds a matching element. If multiple elements are found within the same group,
* each thread with a match will call the callback with its associated element.
*
* @note Synchronizing `group` within `callback_op` is undefined behavior.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @note The `sync_op` function can be used to perform work that requires synchronizing threads in
* `group` inbetween probing steps, where the number of probing steps performed between
* synchronization points is capped by `window_size * cg_size`. The functor will be called right
* after the current probing window has been traversed.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Unary callback functor or device lambda
* @tparam SyncOp Functor or device lambda which accepts the current `group` object
* @tparam CallbackOp Type of unary callback function object
* @tparam SyncOp Type of function object which accepts the current `group` object
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
* @param callback_op Function to call on every element found
* @param callback_op Function to apply to every matched slot
* @param sync_op Function that is allowed to synchronize `group` inbetween probing windows
*/
template <class ProbeKey, class CallbackOp, class SyncOp>
Expand All @@ -1108,7 +1105,7 @@ class open_addressing_ref_impl {
continue;
}
case detail::equal_result::EQUAL: {
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
callback_op(window_slots[i]);
continue;
}
default: {
Expand Down
64 changes: 64 additions & 0 deletions include/cuco/detail/static_map/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,70 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const noexcept
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
}

template <class Key,
class T,
class Extent,
Expand Down
70 changes: 70 additions & 0 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -1249,5 +1249,75 @@ class operator_impl<
}
};

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
class operator_impl<
op::for_each_tag,
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
using base_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using iterator = typename base_type::iterator;
using const_iterator = typename base_type::const_iterator;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;

public:
/**
* @brief For a given key, applies the function object `callback_op` to its match found in the
* container.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
*
* @param key The key to search for
* @param callback_op Function to apply to the copy of the matched key-value pair
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(key, std::forward<CallbackOp>(callback_op));
}

/**
* @brief For a given key, applies the function object `callback_op` to its match found in the
* container.
*
* @note This function uses cooperative group semantics, meaning that any thread may call the
* callback if it finds a matching key-value pair.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @note Synchronizing `group` within `callback_op` is undefined behavior.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
* @param callback_op Function to apply to the copy of the matched key-value pair
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key,
CallbackOp&& callback_op) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(group, key, std::forward<CallbackOp>(callback_op));
}
};

} // namespace detail
} // namespace cuco
Loading

0 comments on commit 118fd1f

Please sign in to comment.