diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 6cfa3156c9..2d3481b4e1 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -109,7 +110,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples); // temporary buffers - std::vector h_wt(n_samples); + auto indices = raft::make_device_vector(handle, n_trials); auto centroidCandidates = raft::make_device_matrix(handle, n_trials, n_features); auto costPerCandidate = raft::make_device_vector(handle, n_trials); auto minClusterDistance = raft::make_device_vector(handle, n_samples); @@ -119,6 +120,17 @@ void kmeansPlusPlus(const raft::handle_t& handle, rmm::device_scalar clusterCost(stream); rmm::device_scalar> minClusterIndexAndDistance(stream); + // Device and matrix views + raft::device_vector_view indices_view(indices.data_handle(), n_trials); + auto const_weights_view = + raft::make_device_vector_view(minClusterDistance.data_handle(), n_samples); + auto const_indices_view = + raft::make_device_vector_view(indices.data_handle(), n_trials); + auto const_X_view = + raft::make_device_matrix_view(X.data_handle(), n_samples, n_features); + raft::device_matrix_view candidates_view( + centroidCandidates.data_handle(), n_trials, n_features); + // L2 norm of X: ||c||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); @@ -133,6 +145,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, stream); } + raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); std::mt19937 gen(params.rng_state.seed); std::uniform_int_distribution<> dis(0, n_samples - 1); @@ -169,20 +182,9 @@ void kmeansPlusPlus(const raft::handle_t& handle, // <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C) // Choose 'n_trials' centroid candidates from X with probability proportional to the squared // distance to the nearest existing cluster - raft::copy(h_wt.data(), minClusterDistance.data_handle(), minClusterDistance.size(), stream); - handle.sync_stream(stream); - // Note - n_trials is relative small here, we don't need raft::gather call - std::discrete_distribution<> d(h_wt.begin(), h_wt.end()); - for (int cIdx = 0; cIdx < n_trials; ++cIdx) { - auto rand_idx = d(gen); - auto randCentroid = raft::make_device_matrix_view( - X.data_handle() + n_features * rand_idx, 1, n_features); - raft::copy(centroidCandidates.data_handle() + cIdx * n_features, - randCentroid.data_handle(), - randCentroid.size(), - stream); - } + raft::random::discrete(handle, rng, indices_view, const_weights_view); + raft::matrix::gather(handle, const_X_view, const_indices_view, candidates_view); // Calculate pairwise distance between X and the centroid candidates // Output - pwd [n_trials x n_samples] diff --git a/cpp/include/raft/random/detail/rng_device.cuh b/cpp/include/raft/random/detail/rng_device.cuh index ef13138beb..6c75a4fa78 100644 --- a/cpp/include/raft/random/detail/rng_device.cuh +++ b/cpp/include/raft/random/detail/rng_device.cuh @@ -681,6 +681,40 @@ __global__ void rngKernel(DeviceState rng_state, return; } +template +__global__ void sample_with_replacement_kernel(DeviceState rng_state, + OutType* out, + const WeightType* weights_csum, + IdxType sampledLen, + IdxType len) +{ + // todo(lsugy): warp-collaborative binary search + + IdxType tid = threadIdx.x + static_cast(blockIdx.x) * blockDim.x; + GenType gen(rng_state, static_cast(tid)); + + if (tid < sampledLen) { + WeightType val_01; + gen.next(val_01); + WeightType val_search = val_01 * weights_csum[len - 1]; + + // Binary search of the first index for which the cumulative sum of weights is larger than the + // generated value + IdxType idx_start = 0; + IdxType idx_end = len; + while (idx_end > idx_start) { + IdxType idx_middle = (idx_start + idx_end) / 2; + WeightType val_middle = weights_csum[idx_middle]; + if (val_search <= val_middle) { + idx_end = idx_middle; + } else { + idx_start = idx_middle + 1; + } + } + out[tid] = static_cast(min(idx_start, len - 1)); + } +} + /** * This kernel is deprecated and should be removed in a future release */ diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 5aecbfcaa2..0a843857e1 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -234,6 +234,49 @@ void laplace( RAFT_CALL_RNG_FUNC(rng_state, call_rng_kernel<1>, rng_state, stream, ptr, len, params); } +template +void call_sample_with_replacement_kernel(DeviceState const& dev_state, + RngState& rng_state, + cudaStream_t stream, + OutType* out, + const WeightType* weights_csum, + IdxType sampledLen, + IdxType len) +{ + IdxType n_threads = 256; + IdxType n_blocks = raft::ceildiv(sampledLen, n_threads); + sample_with_replacement_kernel<<>>( + dev_state, out, weights_csum, sampledLen, len); + rng_state.advance(uint64_t(n_blocks) * n_threads, 1); +} + +template +std::enable_if_t> discrete(RngState& rng_state, + OutType* ptr, + const WeightType* weights, + IndexType sampledLen, + IndexType len, + cudaStream_t stream) +{ + // Compute the cumulative sums of the weights + size_t temp_storage_bytes = 0; + rmm::device_uvector weights_csum(len, stream); + cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, weights, weights_csum.data(), len); + rmm::device_uvector temp_storage(temp_storage_bytes, stream); + cub::DeviceScan::InclusiveSum( + temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len); + + // Sample indices with replacement + RAFT_CALL_RNG_FUNC(rng_state, + call_sample_with_replacement_kernel, + rng_state, + stream, + ptr, + weights_csum.data(), + sampledLen, + len); +} + template void sampleWithoutReplacement(RngState& rng_state, DataT* out, diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 95bfe24a68..504b01ebc3 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -703,6 +703,46 @@ void laplace(const raft::handle_t& handle, detail::laplace(rng_state, ptr, len, mu, scale, handle.get_stream()); } +/** + * @brief Generate random integers, where the probability of i is weights[i]/sum(weights) + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * + * raft::handle_t handle; + * ... + * raft::random::RngState rng(seed); + * auto indices = raft::make_device_vector(handle, n_samples); + * raft::random::discrete(handle, rng, indices.view(), weights); + * @endcode + * + * @tparam OutType integer output type + * @tparam WeightType weight type + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] weights weight array + */ +template +std::enable_if_t> discrete( + const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + raft::device_vector_view weights) +{ + detail::discrete(rng_state, + out.data_handle(), + weights.data_handle(), + out.extent(0), + weights.extent(0), + handle.get_stream()); +} + namespace sample_without_replacement_impl { template struct weight_alias { diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index dae0f6f6b1..a75eb3bff6 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -189,6 +189,7 @@ if(BUILD_TESTS) test/random/multi_variable_gaussian.cu test/random/permute.cu test/random/rng.cu + test/random/rng_discrete.cu test/random/rng_int.cu test/random/rmat_rectangular_generator.cu test/random/sample_without_replacement.cu diff --git a/cpp/test/random/rng_discrete.cu b/cpp/test/random/rng_discrete.cu new file mode 100644 index 0000000000..b7aef51af5 --- /dev/null +++ b/cpp/test/random/rng_discrete.cu @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace random { + +/* In this test we generate pseudo-random integers following a probability distribution defined by + * an array of weights, such that the probability of the integer i is p_i=w_i/sum(w). A histogram of + * the generated integers is compared to the expected probabilities. The histogram is normalized, + * i.e divided by the number of drawn integers n=sampled_len*n_repeat. The expected value for the + * index i of the histogram is E_i=p_i, the standard deviation sigma_i=sqrt(p_i*(1-p_i)/n). + * + * Weights are constructed as a sparse vector containing mostly zeros and a small number of non-zero + * values. The test tolerance used to compare the actual and expected histograms is + * eps=max(sigma_i). For the test to be relevant, the tolerance must be small w.r.t the non-zero + * probabilities. Hence, n_repeat, sampled_len and nnz must be chosen accordingly. The test + * automatically computes the tolerance and will fail if it is estimated too high for the test to be + * relevant. + */ + +template +struct RngDiscreteInputs { + IdxT n_repeat; + IdxT sampled_len; + IdxT len; + IdxT nnz; + GeneratorType gtype; + unsigned long long int seed; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const RngDiscreteInputs& d) +{ + return os << "{" << d.n_repeat << ", " << d.sampled_len << ", " << d.len << ", " << d.nnz << "}"; +} + +template +void update_count( + const LabelT* labels, IdxT* count, IdxT sampled_len, IdxT len, const cudaStream_t& stream) +{ + IdxT num_levels = len + 1; + IdxT lower_level = 0; + IdxT upper_level = len; + + rmm::device_uvector temp_count(len, stream); + + size_t temp_storage_bytes = 0; + RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, + temp_storage_bytes, + labels, + temp_count.data(), + num_levels, + lower_level, + upper_level, + sampled_len, + stream)); + + rmm::device_uvector workspace(temp_storage_bytes, stream); + + RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(workspace.data(), + temp_storage_bytes, + labels, + temp_count.data(), + num_levels, + lower_level, + upper_level, + sampled_len, + stream)); + + raft::linalg::add(count, count, temp_count.data(), len, stream); +} + +template +void normalize_count( + float* histogram, const IdxT* count, float scale, IdxT len, const cudaStream_t& stream) +{ + raft::linalg::unaryOp( + histogram, + count, + len, + [scale] __device__(const IdxT& cnt) { return static_cast(cnt) / scale; }, + stream); +} + +template +class RngDiscreteTest : public ::testing::TestWithParam> { + public: + RngDiscreteTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + out(params.sampled_len, stream), + weights(params.len, stream), + histogram(params.len, stream), + exp_histogram(params.len) + { + } + + protected: + void SetUp() override + { + tolerance = 0.0f; + std::vector h_weights(params.len, WeightT{0}); + std::mt19937 gen(params.seed); + std::uniform_real_distribution dis(WeightT{0.2}, WeightT{2.0}); + WeightT total_weight = WeightT{0}; + for (int i = 0; i < params.nnz; i++) { + h_weights[i] = dis(gen); + total_weight += h_weights[i]; + } + float min_p = 1.f; + for (int i = 0; i < params.nnz; i++) { + float p = static_cast(h_weights[i] / total_weight); + float n = static_cast(params.n_repeat * params.sampled_len); + float sigma = std::sqrt(p * (1.f - p) / n); + tolerance = std::max(tolerance, 4.f * sigma); + min_p = std::min(min_p, p); + } + EXPECT_TRUE(tolerance < 0.5f * min_p) << "Test tolerance (" << tolerance + << ") is too high. Use more samples, more " + "repetitions or less non-zero weights."; + std::shuffle(h_weights.begin(), h_weights.end(), gen); + raft::copy(weights.data(), h_weights.data(), params.len, stream); + + RngState r(params.seed, params.gtype); + raft::device_vector_view out_view(out.data(), out.size()); + auto weights_view = + raft::make_device_vector_view(weights.data(), weights.size()); + + rmm::device_uvector count(params.len, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(count.data(), 0, params.len * sizeof(IdxT), stream)); + for (int iter = 0; iter < params.n_repeat; iter++) { + discrete(handle, r, out_view, weights_view); + update_count(out.data(), count.data(), params.sampled_len, params.len, stream); + } + float scale = static_cast(params.sampled_len * params.n_repeat); + normalize_count(histogram.data(), count.data(), scale, params.len, stream); + + // Compute the expected normalized histogram + for (IdxT i = 0; i < params.len; i++) { + exp_histogram[i] = h_weights[i] / total_weight; + } + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + RngDiscreteInputs params; + float tolerance; + rmm::device_uvector out; + rmm::device_uvector weights; + rmm::device_uvector histogram; + std::vector exp_histogram; +}; + +const std::vector> inputs_i32 = { + {1, 10000, 5, 5, GenPhilox, 123ULL}, + {1, 10000, 10, 7, GenPhilox, 456ULL}, + {1000, 100, 10000, 20, GenPhilox, 123ULL}, + {1, 10000, 5, 5, GenPC, 1234ULL}, +}; +const std::vector> inputs_i64 = { + {1, 10000, 5, 5, GenPhilox, 123ULL}, + {1, 10000, 10, 7, GenPhilox, 456ULL}, + {1000, 100, 10000, 20, GenPhilox, 123ULL}, + {1, 10000, 5, 5, GenPC, 1234ULL}, +}; + +#define RNG_DISCRETE_TEST(test_type, test_name, test_inputs) \ + typedef RAFT_DEPAREN(test_type) test_name; \ + TEST_P(test_name, Result) \ + { \ + ASSERT_TRUE(devArrMatchHost(exp_histogram.data(), \ + histogram.data(), \ + exp_histogram.size(), \ + CompareApprox(tolerance))); \ + } \ + INSTANTIATE_TEST_CASE_P(ReduceTests, test_name, ::testing::ValuesIn(test_inputs)) + +RNG_DISCRETE_TEST((RngDiscreteTest), RngDiscreteTestI32FI32, inputs_i32); +RNG_DISCRETE_TEST((RngDiscreteTest), RngDiscreteTestU32FI32, inputs_i32); +RNG_DISCRETE_TEST((RngDiscreteTest), RngDiscreteTestI64FI32, inputs_i32); +RNG_DISCRETE_TEST((RngDiscreteTest), RngDiscreteTestI32DI32, inputs_i32); + +// Disable IdxT=int64_t test due to CUB error: https://github.com/NVIDIA/cub/issues/192 +// RNG_DISCRETE_TEST((RngDiscreteTest), RngDiscreteTestI32FI64, inputs_i64); + +} // namespace random +} // namespace raft