Skip to content

Commit

Permalink
Merge pull request #117 from vyasr/fix/static_map_device_contains
Browse files Browse the repository at this point in the history
Fix device side contains implementation and add tests
  • Loading branch information
jrhemstad authored Nov 11, 2021
2 parents f0eecb2 + 2373d71 commit 7f6f1c2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
6 changes: 3 additions & 3 deletions include/cuco/detail/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -441,14 +441,14 @@ static_map<Key, Value, Scope, Allocator>::device_view::find(CG g,
template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename Hash, typename KeyEqual>
__device__ bool static_map<Key, Value, Scope, Allocator>::device_view::contains(
Key const& k, Hash hash, KeyEqual key_equal) noexcept
Key const& k, Hash hash, KeyEqual key_equal) const noexcept
{
auto current_slot = initial_slot(k, hash);

while (true) {
auto const existing_key = current_slot->first.load(cuda::std::memory_order_relaxed);

if (detail::bitwise_compare(existing_key, empty_key_sentinel_)) { return false; }
if (detail::bitwise_compare(existing_key, this->empty_key_sentinel_)) { return false; }

if (key_equal(existing_key, k)) { return true; }

Expand All @@ -459,7 +459,7 @@ __device__ bool static_map<Key, Value, Scope, Allocator>::device_view::contains(
template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename CG, typename Hash, typename KeyEqual>
__device__ bool static_map<Key, Value, Scope, Allocator>::device_view::contains(
CG g, Key const& k, Hash hash, KeyEqual key_equal) noexcept
CG g, Key const& k, Hash hash, KeyEqual key_equal) const noexcept
{
auto current_slot = initial_slot(g, k, hash);

Expand Down
10 changes: 4 additions & 6 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,11 @@ class static_map {
using const_iterator = pair_atomic_type const*;
using slot_type = slot_type;

private:
pair_atomic_type* slots_{}; ///< Pointer to flat slots storage
std::size_t capacity_{}; ///< Total number of slots
Key empty_key_sentinel_{}; ///< Key value that represents an empty slot
Value empty_value_sentinel_{}; ///< Initial Value of empty slot
pair_atomic_type* slots_{}; ///< Pointer to flat slots storage
std::size_t capacity_{}; ///< Total number of slots

protected:
__host__ __device__ device_view_base(pair_atomic_type* slots,
std::size_t capacity,
Key empty_key_sentinel,
Expand Down Expand Up @@ -961,7 +959,7 @@ class static_map {
typename KeyEqual = thrust::equal_to<key_type>>
__device__ bool contains(Key const& k,
Hash hash = Hash{},
KeyEqual key_equal = KeyEqual{}) noexcept;
KeyEqual key_equal = KeyEqual{}) const noexcept;

/**
* @brief Indicates whether the key `k` was inserted into the map.
Expand Down Expand Up @@ -989,7 +987,7 @@ class static_map {
__device__ bool contains(CG g,
Key const& k,
Hash hash = Hash{},
KeyEqual key_equal = KeyEqual{}) noexcept;
KeyEqual key_equal = KeyEqual{}) const noexcept;
}; // class device_view

/**
Expand Down
12 changes: 12 additions & 0 deletions tests/static_map/static_map_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ TEST_CASE("User defined key and value type", "")
none_of(contained.begin(), contained.end(), [] __device__(bool const& b) { return b; }));
}

SECTION("All inserted keys-value pairs should be contained")
{
thrust::device_vector<bool> contained(num_pairs);
map.insert(insert_pairs, insert_pairs + num_pairs, hash_key_pair{}, key_pair_equals{});
auto view = map.get_device_view();
REQUIRE(all_of(insert_pairs,
insert_pairs + num_pairs,
[view] __device__(cuco::pair_type<Key, Value> const& pair) {
return view.contains(pair.first, hash_key_pair{}, key_pair_equals{});
}));
}

SECTION("Inserting unique keys should return insert success.")
{
auto m_view = map.get_device_mutable_view();
Expand Down

0 comments on commit 7f6f1c2

Please sign in to comment.