Skip to content

Commit

Permalink
Fix uniform random key generation
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Oct 30, 2024
1 parent 494f321 commit 5159459
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions include/cuco/utility/key_generator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
#include <cuco/detail/pair/helpers.cuh>
#include <cuco/detail/utility/strong_type.cuh>

#include <cuda/functional>
#include <cuda/std/cmath>
#include <cuda/std/functional> // TODO include <cuda/std/algorithm> instead once available
#include <cuda/std/limits>
#include <cuda/std/span>
#include <thrust/device_vector.h>
Expand Down Expand Up @@ -80,6 +81,7 @@ struct gaussian : public cuco::detail::strong_type<double> {
} // namespace distribution

namespace detail {

/**
* @brief Generate uniform functor
*
Expand All @@ -94,9 +96,10 @@ struct generate_uniform_fn {
*
* @param num Number of elements to generate
* @param dist Random number distribution
* @param seed Random seed
*/
__host__ __device__ constexpr generate_uniform_fn(std::size_t num, Dist dist)
: num_{num}, dist_{dist}
__host__ __device__ constexpr generate_uniform_fn(std::size_t num, Dist dist, std::size_t seed)
: num_{num}, dist_{dist}, seed_{seed}
{
}

Expand All @@ -107,16 +110,23 @@ struct generate_uniform_fn {
*
* @return A resulting random number
*/
__host__ __device__ constexpr T operator()(std::size_t seed) const noexcept
__host__ __device__ constexpr T operator()(std::size_t idx) const noexcept
{
RNG rng;
thrust::uniform_int_distribution<T> uniform_dist{1, static_cast<T>(num_ / dist_.value)};
rng.seed(seed);
// Improved seeding using a linear congruential generator
rng.seed(seed_ + idx * 1664525ull + 1013904223ull);
// Calculate number of unique keys
auto num_unique_keys = cuda::std::max<size_t>(
1ull,
static_cast<size_t>(
cuda::std::ceil(static_cast<double>(num_) / static_cast<double>(dist_.value))));
thrust::uniform_int_distribution<T> uniform_dist{0, static_cast<T>(num_unique_keys - 1)};
return uniform_dist(rng);
}

std::size_t num_; ///< Number of elements to generate
Dist dist_; ///< Random number distribution
std::size_t num_;
Dist dist_;
std::size_t seed_;
};

/**
Expand Down Expand Up @@ -270,18 +280,17 @@ class key_generator {
using value_type = typename std::iterator_traits<OutputIt>::value_type;

if constexpr (std::is_same_v<Dist, distribution::unique>) {
thrust::sequence(exec_policy, out_begin, out_end, 0);
thrust::sequence(exec_policy, out_begin, out_end, value_type{0});
thrust::shuffle(exec_policy, out_begin, out_end, this->rng_);
} else if constexpr (std::is_same_v<Dist, distribution::uniform>) {
size_t num_keys = thrust::distance(out_begin, out_end);

thrust::counting_iterator<size_t> seeds(this->rng_());
size_t seed = this->rng_();

thrust::transform(exec_policy,
seeds,
seeds + num_keys,
thrust::make_counting_iterator<size_t>(0),
thrust::make_counting_iterator<size_t>(num_keys),
out_begin,
detail::generate_uniform_fn<value_type, Dist, RNG>{num_keys, dist});
detail::generate_uniform_fn<value_type, Dist, RNG>{num_keys, dist, seed});
} else if constexpr (std::is_same_v<Dist, distribution::gaussian>) {
size_t num_keys = thrust::distance(out_begin, out_end);

Expand Down

0 comments on commit 5159459

Please sign in to comment.