Skip to content

Commit

Permalink
Merge pull request #118 from vyasr/feature/static_map_insert_if
Browse files Browse the repository at this point in the history
Add static_map::insert_if.
  • Loading branch information
jrhemstad authored Nov 19, 2021
2 parents 7f6f1c2 + 1af02fa commit bb8c34d
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 0 deletions.
37 changes: 37 additions & 0 deletions include/cuco/detail/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,43 @@ void static_map<Key, Value, Scope, Allocator>::insert(InputIt first,
size_ += h_num_successes;
}

/**
 * @brief Host-side implementation of `static_map::insert_if`.
 *
 * Launches `detail::insert_if_n` on `stream` over the `std::distance(first, last)`
 * input elements, waits for the stream to finish, and adds the number of
 * successful insertions reported by the kernel to `size_`.
 *
 * Returns immediately (no device work is enqueued) when the input range is empty.
 *
 * @param first Beginning of the sequence of key/value pairs
 * @param last End of the sequence of key/value pairs
 * @param stencil Beginning of the stencil sequence
 * @param pred Predicate applied to `*(stencil + i)` to decide whether `*(first + i)` is inserted
 * @param hash The unary function used to hash each key
 * @param key_equal The binary function used to compare two keys for equality
 * @param stream CUDA stream on which the memset, kernel and copy-back are enqueued
 */
template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename InputIt,
          typename StencilIt,
          typename Predicate,
          typename Hash,
          typename KeyEqual>
void static_map<Key, Value, Scope, Allocator>::insert_if(InputIt first,
                                                         InputIt last,
                                                         StencilIt stencil,
                                                         Predicate pred,
                                                         Hash hash,
                                                         KeyEqual key_equal,
                                                         cudaStream_t stream)
{
  auto num_keys = std::distance(first, last);
  if (num_keys == 0) { return; }

  auto constexpr block_size = 128;
  auto constexpr stride     = 1;
  auto constexpr tile_size  = 4;
  // One tile of `tile_size` threads per key, rounded up to a whole number of blocks.
  auto const grid_size = (tile_size * num_keys + stride * block_size - 1) / (stride * block_size);
  auto view            = get_device_mutable_view();

  // TODO: memset an atomic variable is unsafe
  static_assert(sizeof(std::size_t) == sizeof(atomic_ctr_type));
  CUCO_CUDA_TRY(cudaMemsetAsync(num_successes_, 0, sizeof(atomic_ctr_type), stream));
  std::size_t h_num_successes{0};

  detail::insert_if_n<block_size, tile_size><<<grid_size, block_size, 0, stream>>>(
    first, num_keys, num_successes_, view, stencil, pred, hash, key_equal);
  // A kernel launch does not return a status itself; query it explicitly so an
  // invalid launch configuration is reported instead of silently leaving the
  // success counter at zero.
  CUCO_CUDA_TRY(cudaGetLastError());

  CUCO_CUDA_TRY(cudaMemcpyAsync(
    &h_num_successes, num_successes_, sizeof(atomic_ctr_type), cudaMemcpyDeviceToHost, stream));
  // The host copy of the counter is only valid once the stream has drained.
  CUCO_CUDA_TRY(cudaStreamSynchronize(stream));

  size_ += h_num_successes;
}

template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename InputIt, typename OutputIt, typename Hash, typename KeyEqual>
void static_map<Key, Value, Scope, Allocator>::find(
Expand Down
70 changes: 70 additions & 0 deletions include/cuco/detail/static_map_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,76 @@ __global__ void insert(
if (threadIdx.x == 0) { *num_successes += block_num_successes; }
}

/**
 * @brief Inserts key/value pairs in the range `[first, first + n)` if `pred` of the
 * corresponding stencil returns true.
 *
 * The key/value pair `*(first + i)` is inserted if `pred( *(stencil + i) )` returns true.
 * Each insertion is performed cooperatively by a tile of `tile_size` threads; every
 * thread of a tile works on the same index `i`.
 *
 * If multiple keys in `[first, first + n)` compare equal, it is unspecified which
 * element is inserted.
 *
 * The kernel must be launched with exactly `block_size` threads per block, since
 * `block_size` is used both for index computation and for the block-wide reduction.
 *
 * @tparam block_size The size of the thread block
 * @tparam tile_size The number of threads in the Cooperative Groups used to perform insert
 * @tparam InputIt Device accessible input iterator whose `value_type` is
 * convertible to the map's `value_type`
 * @tparam atomicT Type of atomic storage
 * @tparam viewT Type of device view allowing access of hash map storage
 * @tparam StencilIt Device accessible random access iterator whose value_type is
 * convertible to Predicate's argument type
 * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
 * and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
 * @tparam Hash Unary callable type
 * @tparam KeyEqual Binary callable type
 * @param first Beginning of the sequence of key/value pairs
 * @param n Number of elements to insert
 * @param num_successes The number of successfully inserted key/value pairs
 * @param view Mutable device view used to access the hash map's slot storage
 * @param stencil Beginning of the stencil sequence
 * @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
 * @param hash The unary function to apply to hash each key
 * @param key_equal The binary function used to compare two keys for equality
 */
template <std::size_t block_size,
          uint32_t tile_size,
          typename InputIt,
          typename atomicT,
          typename viewT,
          typename StencilIt,
          typename Predicate,
          typename Hash,
          typename KeyEqual>
__global__ void insert_if_n(InputIt first,
                            std::size_t n,
                            atomicT* num_successes,
                            viewT view,
                            StencilIt stencil,
                            Predicate pred,
                            Hash hash,
                            KeyEqual key_equal)
{
  using BlockReduce = cub::BlockReduce<std::size_t, block_size>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  std::size_t thread_num_successes = 0;

  auto tile = cg::tiled_partition<tile_size>(cg::this_thread_block());
  auto tid  = block_size * blockIdx.x + threadIdx.x;
  // All threads of a tile share the same element index.
  auto i = tid / tile_size;

  while (i < n) {
    if (pred(*(stencil + i))) {
      // force conversion to the map's value_type
      typename viewT::value_type const insert_pair{*(first + i)};
      // The cooperative insert reports its result to every thread of the tile, so
      // only the tile's rank-0 thread records the success; counting on every
      // thread would inflate the total by a factor of `tile_size`.
      if (view.insert(tile, insert_pair, hash, key_equal) && tile.thread_rank() == 0) {
        thread_num_successes++;
      }
    }
    // Advance by the number of tiles in the whole grid.
    i += (gridDim.x * block_size) / tile_size;
  }

  // compute number of successfully inserted elements for each block
  // and atomically add to the grand total
  std::size_t block_num_successes = BlockReduce(temp_storage).Sum(thread_num_successes);
  // The reduced sum is only consumed on thread 0 of the block.
  if (threadIdx.x == 0) {
    num_successes->fetch_add(block_num_successes, cuda::std::memory_order_relaxed);
  }
}

/**
* @brief Finds the values corresponding to all keys in the range `[first, last)`.
*
Expand Down
36 changes: 36 additions & 0 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,42 @@ class static_map {
typename KeyEqual = thrust::equal_to<key_type>>
void insert(InputIt first, InputIt last, Hash hash = Hash{}, KeyEqual key_equal = KeyEqual{});

/**
 * @brief Inserts key/value pairs in the range `[first, last)` if `pred`
 * of the corresponding stencil returns true.
 *
 * The key/value pair `*(first + i)` is inserted if `pred( *(stencil + i) )` returns true.
 *
 * If the range `[first, last)` is empty the function returns without enqueuing any
 * work. Otherwise the work is enqueued on `stream` and the function synchronizes
 * `stream` before returning, after which the map's size has been increased by the
 * number of successful insertions.
 *
 * @tparam InputIt Device accessible random access iterator whose `value_type` is
 * convertible to the map's `value_type`
 * @tparam StencilIt Device accessible random access iterator whose value_type is
 * convertible to Predicate's argument type
 * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
 * argument type is convertible from `std::iterator_traits<StencilIt>::value_type`.
 * @tparam Hash Unary callable type
 * @tparam KeyEqual Binary callable type
 * @param first Beginning of the sequence of key/value pairs
 * @param last End of the sequence of key/value pairs
 * @param stencil Beginning of the stencil sequence; it is read at offsets
 * `[0, std::distance(first, last))`
 * @param pred Predicate to test on every element in the range `[stencil, stencil +
 * std::distance(first, last))`
 * @param hash The unary function to hash each key
 * @param key_equal The binary function to compare two keys for equality
 * @param stream CUDA stream used for insert
 */
template <typename InputIt,
typename StencilIt,
typename Predicate,
typename Hash = cuco::detail::MurmurHash3_32<key_type>,
typename KeyEqual = thrust::equal_to<key_type>>
void insert_if(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
Hash hash = Hash{},
KeyEqual key_equal = KeyEqual{},
cudaStream_t stream = 0);

/**
* @brief Finds the values corresponding to all keys in the range `[first, last)`.
*
Expand Down
25 changes: 25 additions & 0 deletions tests/static_map/static_map_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,31 @@ TEST_CASE("User defined key and value type", "")
REQUIRE(all_of(contained.begin(), contained.end(), [] __device__(bool const& b) { return b; }));
}

SECTION("All conditionally inserted keys-value pairs should be contained")
{
  // The stencil is simply the element index: pair i is inserted iff i is even.
  auto const stencil_begin = thrust::counting_iterator<int>(0);
  auto const is_even = [] __device__(auto const& key) { return (key % 2) == 0; };

  map.insert_if(insert_pairs,
                insert_pairs + num_pairs,
                stencil_begin,
                is_even,
                hash_key_pair{},
                key_pair_equals{});

  // Query every key that was offered for insertion.
  thrust::device_vector<bool> contained(num_pairs);
  map.contains(
    insert_keys.begin(), insert_keys.end(), contained.begin(), hash_key_pair{}, key_pair_equals{});

  // A key must be reported as contained exactly when its index is even.
  auto const matches_parity = [] __device__(auto const& idx_contained, auto const& idx) {
    return ((idx % 2) == 0) == idx_contained;
  };
  REQUIRE(thrust::equal(thrust::device,
                        contained.begin(),
                        contained.end(),
                        thrust::counting_iterator<int>(0),
                        matches_parity));
}

SECTION("Non-inserted keys-value pairs should not be contained")
{
thrust::device_vector<bool> contained(num_pairs);
Expand Down

0 comments on commit bb8c34d

Please sign in to comment.