Skip to content

Commit

Permalink
Proclaim more return types for CCCL 2.2.0 compatibility. (#405)
Browse files Browse the repository at this point in the history
This PR adds more uses of `cuda::proclaim_return_type` so that #404 can
pass CI. This PR can be merged immediately, but #404 needs to wait a bit
longer.
  • Loading branch information
bdice authored Dec 8, 2023
1 parent 7fae640 commit 368cda3
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 47 deletions.
8 changes: 6 additions & 2 deletions examples/static_map/custom_type_example.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <thrust/logical.h>
#include <thrust/transform.h>

#include <cuda/functional>

// User-defined key type
#if !defined(CUCO_HAS_INDEPENDENT_THREADS)
struct custom_key_type {
Expand Down Expand Up @@ -88,7 +90,8 @@ int main(void)
// Create an iterator of input key/value pairs
auto pairs_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<int32_t>(0),
[] __device__(auto i) { return cuco::make_pair(custom_key_type{i}, custom_value_type{i}); });
cuda::proclaim_return_type<cuco::pair<custom_key_type, custom_value_type>>(
[] __device__(auto i) { return cuco::make_pair(custom_key_type{i}, custom_value_type{i}); }));

// Construct a map with 100,000 slots using the given empty key/value sentinels. Note the
// capacity is chosen knowing we will insert 80,000 keys, for an load factor of 80%.
Expand All @@ -101,7 +104,8 @@ int main(void)
// Reproduce inserted keys
auto insert_keys =
thrust::make_transform_iterator(thrust::make_counting_iterator<int32_t>(0),
[] __device__(auto i) { return custom_key_type{i}; });
cuda::proclaim_return_type<custom_key_type>(
[] __device__(auto i) { return custom_key_type{i}; }));

thrust::device_vector<bool> contained(num_pairs);

Expand Down
9 changes: 6 additions & 3 deletions tests/dynamic_map/unique_sequence_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

TEMPLATE_TEST_CASE_SIG("Unique sequence of keys",
"",
((typename Key, typename Value), Key, Value),
Expand All @@ -48,9 +50,10 @@ TEMPLATE_TEST_CASE_SIG("Unique sequence of keys",
thrust::sequence(thrust::device, d_keys.begin(), d_keys.end());
thrust::sequence(thrust::device, d_values.begin(), d_values.end());

auto pairs_begin =
thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
auto pairs_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<int>(0),
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); }));

thrust::device_vector<Value> d_results(num_keys);
thrust::device_vector<bool> d_contained(num_keys);
Expand Down
3 changes: 2 additions & 1 deletion tests/static_map/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ TEMPLATE_TEST_CASE_SIG("Heterogeneous lookup",

auto insert_pairs = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<InsertKey, Value>(i, i); });
cuda::proclaim_return_type<cuco::pair<InsertKey, Value>>(
[] __device__(auto i) { return cuco::pair<InsertKey, Value>(i, i); }));
auto probe_keys = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
cuda::proclaim_return_type<ProbeKey>([] __device__(auto i) { return ProbeKey{i}; }));
Expand Down
3 changes: 2 additions & 1 deletion tests/static_map/insert_or_assign_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ __inline__ void test_insert_or_assign(Map& map, size_type num_keys)
// Query pairs have the same keys but different payloads
auto query_pairs_begin = thrust::make_transform_iterator(
thrust::counting_iterator<size_type>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i * 2); });
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i * 2); }));

map.insert_or_assign(query_pairs_begin, query_pairs_begin + num_keys);

Expand Down
51 changes: 29 additions & 22 deletions tests/static_multimap/custom_pair_retrieve_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

#include <cooperative_groups.h>

// Custom pair equal
Expand Down Expand Up @@ -93,9 +95,9 @@ void test_non_shmem_pair_retrieve(Map& map, std::size_t const num_pairs)
thrust::counting_iterator<int>(0),
thrust::counting_iterator<int>(num_pairs),
d_pairs.begin(),
[] __device__(auto i) {
cuda::proclaim_return_type<cuco::pair<Key, Value>>([] __device__(auto i) {
return cuco::pair<Key, Value>{i / 2, i};
});
}));

auto pair_begin = d_pairs.begin();

Expand All @@ -106,15 +108,17 @@ void test_non_shmem_pair_retrieve(Map& map, std::size_t const num_pairs)
thrust::counting_iterator<int>(0),
thrust::counting_iterator<int>(num_pairs),
pair_begin,
[] __device__(auto i) {
cuda::proclaim_return_type<cuco::pair<Key, Value>>([] __device__(auto i) {
return cuco::pair<Key, Value>{i, i};
});
}));

// create an array of prefix sum
thrust::device_vector<int> d_scan(num_pairs);
auto count_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator<int>(0),
[num_pairs] __device__(auto i) { return i < (num_pairs / 2) ? 2 : 1; });
auto count_begin =
thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
cuda::proclaim_return_type<int>([num_pairs] __device__(auto i) {
return i < (num_pairs / 2) ? 2 : 1;
}));
thrust::exclusive_scan(thrust::device, count_begin, count_begin + num_pairs, d_scan.begin(), 0);

auto constexpr gold_size = 300;
Expand Down Expand Up @@ -151,21 +155,24 @@ void test_non_shmem_pair_retrieve(Map& map, std::size_t const num_pairs)
thrust::sort(thrust::device, contained_vals.begin(), contained_vals.end());

// set gold references
auto gold_probe = thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
[num_pairs] __device__(auto i) {
if (i < num_pairs) { return i / 2; }
return i - (int(num_pairs) / 2);
});
auto gold_contained_key = thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
[num_pairs] __device__(auto i) {
if (i < num_pairs / 2) { return -1; }
return (i - (int(num_pairs) / 2)) / 2;
});
auto gold_contained_val = thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
[num_pairs] __device__(auto i) {
if (i < num_pairs / 2) { return -1; }
return i - (int(num_pairs) / 2);
});
auto gold_probe =
thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
cuda::proclaim_return_type<int>([num_pairs] __device__(auto i) {
if (i < num_pairs) { return i / 2; }
return i - (int(num_pairs) / 2);
}));
auto gold_contained_key =
thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
cuda::proclaim_return_type<int>([num_pairs] __device__(auto i) {
if (i < num_pairs / 2) { return -1; }
return (i - (int(num_pairs) / 2)) / 2;
}));
auto gold_contained_val =
thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0),
cuda::proclaim_return_type<int>([num_pairs] __device__(auto i) {
if (i < num_pairs / 2) { return -1; }
return i - (int(num_pairs) / 2);
}));

auto key_equal = thrust::equal_to<Key>{};
auto value_equal = thrust::equal_to<Value>{};
Expand Down
14 changes: 9 additions & 5 deletions tests/static_multimap/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

#include <tuple>

// insert key type
Expand Down Expand Up @@ -103,11 +105,13 @@ TEMPLATE_TEST_CASE("Heterogeneous lookup",
cuco::linear_probing<1, custom_hasher>>
map{capacity, cuco::empty_key<Key>{sentinel_key}, cuco::empty_value<Value>{sentinel_value}};

auto insert_pairs =
thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });
auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return ProbeKey(i); });
auto insert_pairs = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
cuda::proclaim_return_type<cuco::pair<Key, Value>>(
[] __device__(auto i) { return cuco::pair<Key, Value>(i, i); }));
auto probe_keys = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
cuda::proclaim_return_type<ProbeKey>([] __device__(auto i) { return ProbeKey(i); }));

SECTION("All inserted keys-value pairs should be contained")
{
Expand Down
10 changes: 7 additions & 3 deletions tests/static_set/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

#include <tuple>

// insert key type
Expand Down Expand Up @@ -99,9 +101,11 @@ TEMPLATE_TEST_CASE_SIG(
capacity, cuco::empty_key<Key>{sentinel_key}, custom_key_equal{}, probe};

auto insert_keys = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0), [] __device__(auto i) { return InsertKey(i); });
auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return ProbeKey(i); });
thrust::counting_iterator<int>(0),
cuda::proclaim_return_type<InsertKey>([] __device__(auto i) { return InsertKey(i); }));
auto probe_keys = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0),
cuda::proclaim_return_type<ProbeKey>([] __device__(auto i) { return ProbeKey(i); }));

SECTION("All inserted keys should be contained")
{
Expand Down
7 changes: 5 additions & 2 deletions tests/static_set/insert_and_find_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

template <typename Set>
__inline__ void test_insert_and_find(Set& set, std::size_t num_keys)
{
Expand All @@ -34,8 +36,9 @@ __inline__ void test_insert_and_find(Set& set, std::size_t num_keys)
if constexpr (cg_size == 1) {
return thrust::counting_iterator<Key>(0);
} else {
return thrust::make_transform_iterator(thrust::counting_iterator<Key>(0),
[] __device__(auto i) { return i / cg_size; });
return thrust::make_transform_iterator(
thrust::counting_iterator<Key>(0),
cuda::proclaim_return_type<Key>([] __device__(auto i) { return i / cg_size; }));
}
}();
auto const keys_end = [&]() {
Expand Down
21 changes: 13 additions & 8 deletions tests/static_set/unique_sequence_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include <catch2/catch_template_test_macros.hpp>

#include <cuda/functional>

using size_type = int32_t;

template <typename Set>
Expand All @@ -43,8 +45,10 @@ __inline__ void test_unique_sequence(Set& set, size_type num_keys)
auto keys_begin = d_keys.begin();
thrust::device_vector<bool> d_contained(num_keys);

auto zip_equal = [] __device__(auto const& p) { return thrust::get<0>(p) == thrust::get<1>(p); };
auto is_even = [] __device__(auto const& i) { return i % 2 == 0; };
auto zip_equal = cuda::proclaim_return_type<bool>(
[] __device__(auto const& p) { return thrust::get<0>(p) == thrust::get<1>(p); });
auto is_even =
cuda::proclaim_return_type<bool>([] __device__(auto const& i) { return i % 2 == 0; });

SECTION("Non-inserted keys should not be contained.")
{
Expand Down Expand Up @@ -73,12 +77,13 @@ __inline__ void test_unique_sequence(Set& set, size_type num_keys)
REQUIRE(set.size() == num_keys / 2);

set.contains(keys_begin, keys_begin + num_keys, d_contained.begin());
REQUIRE(cuco::test::equal(d_contained.begin(),
d_contained.end(),
thrust::counting_iterator<std::size_t>(0),
[] __device__(auto const& idx_contained, auto const& idx) {
return ((idx % 2) == 0) == idx_contained;
}));
REQUIRE(cuco::test::equal(
d_contained.begin(),
d_contained.end(),
thrust::counting_iterator<std::size_t>(0),
cuda::proclaim_return_type<bool>([] __device__(auto const& idx_contained, auto const& idx) {
return ((idx % 2) == 0) == idx_contained;
})));
}

set.insert(keys_begin, keys_begin + num_keys);
Expand Down

0 comments on commit 368cda3

Please sign in to comment.