Skip to content

Commit

Permalink
Add consistent for_each APIs for cuco hash tables (NVIDIA#632)
Browse files Browse the repository at this point in the history
This PR adds host and device `for_each` APIs for all cuco hash tables.

---------

Co-authored-by: Daniel Jünger <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 2, 2024
1 parent b29b608 commit 5b4a80e
Show file tree
Hide file tree
Showing 11 changed files with 685 additions and 120 deletions.
7 changes: 3 additions & 4 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
#include <cuco/operator.hpp>

#include <cuda/atomic>
#include <cuda/std/functional>
#include <cuda/std/type_traits>
#include <thrust/tuple.h>
#include <cuda/std/utility>

#include <cooperative_groups.h>

Expand Down Expand Up @@ -1335,7 +1334,7 @@ class operator_impl<
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(key, std::forward<CallbackOp>(callback_op));
ref_.impl_.for_each(key, cuda::std::forward<CallbackOp>(callback_op));
}

/**
Expand Down Expand Up @@ -1363,7 +1362,7 @@ class operator_impl<
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(group, key, std::forward<CallbackOp>(callback_op));
ref_.impl_.for_each(group, key, cuda::std::forward<CallbackOp>(callback_op));
}
};

Expand Down
67 changes: 67 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap.inl
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,73 @@ void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator,
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
for_each_async(InputIt first,
InputIt last,
CallbackOp&& callback_op,
cuda::stream_ref stream) const noexcept
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
}

template <class Key,
class T,
class Extent,
Expand Down
150 changes: 106 additions & 44 deletions include/cuco/detail/static_multiset/static_multiset.inl
Original file line number Diff line number Diff line change
Expand Up @@ -284,17 +284,12 @@ template <class Key,
class ProbingScheme,
class Allocator,
class Storage>
template <class InputProbeIt, class OutputProbeIt, class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
InputProbeIt first,
InputProbeIt last,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream) const
template <typename CallbackOp>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
return this->impl_->retrieve(
first, last, output_probe, output_match, this->ref(op::retrieve), stream);
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
stream.wait();
}

template <class Key,
Expand All @@ -304,24 +299,11 @@ template <class Key,
class ProbingScheme,
class Allocator,
class Storage>
template <class InputProbeIt,
class ProbeEqual,
class ProbeHash,
class OutputProbeIt,
class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
InputProbeIt first,
InputProbeIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream) const
template <typename CallbackOp>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const
{
auto const probe_ref =
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
return this->impl_->retrieve(first, last, output_probe, output_match, probe_ref, stream);
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
}

template <class Key,
Expand All @@ -331,24 +313,31 @@ template <class Key,
class ProbingScheme,
class Allocator,
class Storage>
template <class InputProbeIt,
class ProbeEqual,
class ProbeHash,
class OutputProbeIt,
class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve_outer(
InputProbeIt first,
InputProbeIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream) const
template <typename InputIt, typename CallbackOp>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
{
auto const probe_ref =
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
return this->impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream);
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
stream.wait();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
for_each_async(InputIt first,
InputIt last,
CallbackOp&& callback_op,
cuda::stream_ref stream) const noexcept
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
}

template <class Key,
Expand Down Expand Up @@ -412,6 +401,79 @@ static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <class InputProbeIt, class OutputProbeIt, class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
InputProbeIt first,
InputProbeIt last,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream) const
{
return impl_->retrieve(first, last, output_probe, output_match, this->ref(op::retrieve), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <class InputProbeIt,
class ProbeEqual,
class ProbeHash,
class OutputProbeIt,
class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
InputProbeIt first,
InputProbeIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream) const
{
auto const probe_ref =
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
return impl_->retrieve(first, last, output_probe, output_match, probe_ref, stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <class InputProbeIt,
class ProbeEqual,
class ProbeHash,
class OutputProbeIt,
class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve_outer(
InputProbeIt first,
InputProbeIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream) const
{
auto const probe_ref =
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
return impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
60 changes: 60 additions & 0 deletions include/cuco/detail/static_set/static_set.inl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,66 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
stream.wait();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
stream.wait();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const noexcept
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
Loading

0 comments on commit 5b4a80e

Please sign in to comment.