From a4d67caedf7684796c66fcf4347f220690de67eb Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 3 May 2022 17:11:36 -0400 Subject: [PATCH 01/10] Updating raft rng public API to add docs and be more explicit --- cpp/include/raft/random/rng.cuh | 388 ++++++++++++++++++++++++++++++-- cpp/test/random/rng.cu | 27 +-- 2 files changed, 384 insertions(+), 31 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 33d712ac15..caf551edca 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -22,28 +22,380 @@ #include "detail/rng_impl.cuh" #include "detail/rng_impl_deprecated.cuh" // necessary for now (to be removed) #include "rng_state.hpp" +#include namespace raft { namespace random { -using detail::bernoulli; -using detail::exponential; -using detail::fill; -using detail::gumbel; -using detail::laplace; -using detail::logistic; -using detail::lognormal; -using detail::normal; -using detail::normalInt; -using detail::normalTable; -using detail::rayleigh; -using detail::scaled_bernoulli; -using detail::uniform; -using detail::uniformInt; - -using detail::sampleWithoutReplacement; - -using detail::affine_transform_params; +/** + * @brief Generate uniformly distributed numbers in the given range + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output array + * @param[in] len the number of elements in the output + * @param[in] start start of the range + * @param[in] end end of the range + */ +template +void uniform(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType len, + OutType start, + OutType end) +{ + detail::uniform(rng_state, ptr, len, start, end, handle.get_stream()); +} + +/** + * @brief Generate uniformly distributed integers in the given range + * 
@tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output array + * @param[in] len the number of elements in the output + * @param[in] start start of the range + * @param[in] end end of the range + */ +template +void uniformInt(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType len, + OutType start, + OutType end) +{ + detail::uniformInt(rng_state, ptr, len, start, end, handle.get_stream()); +} + +/** + * @brief Generate normal distributed numbers + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output array + * @param[in] len the number of elements in the output + * @param[in] mu mean of the distribution + * @param[in] sigma std-dev of the distribution + */ +template +void normal(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType len, + OutType mu, + OutType sigma) +{ + detail::normal(rng_state, ptr, len, mu, sigma, handle.get_stream()); +} + +/** + * @brief Generate normal distributed integers + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output array + * @param[in] len the number of elements in the output + * @param[in] mu mean of the distribution + * @param[in] sigma std-dev of the distribution + */ +template +void normalInt(raft::handle_t& const handle, + RngState& rng_state, + IntType* ptr, + LenType len, + IntType mu, + IntType sigma) +{ + detail::normalInt(rng_state, ptr, len, mu, 
sigma, handle.get_stream()); +} + +/** + * @brief Generate normal distributed table according to the given set of + * means and scalar standard deviations. + * + * Each row in this table conforms to a normally distributed n-dim vector + * whose mean is the input vector and standard deviation is the corresponding + * vector or scalar. Correlations among the dimensions itself is assumed to + * be absent. + * + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output table (dim = n_rows x n_cols) + * @param[in] n_rows number of rows in the table + * @param[in] n_cols number of columns in the table + * @param[in] mu_vec mean vector (dim = n_cols x 1). + * @param[in] sigma_vec std-dev vector of each component (dim = n_cols x 1). Pass + * a nullptr to use the same scalar 'sigma' across all components + * @param[in] sigma scalar sigma to be used if 'sigma_vec' is nullptr + */ +template +void normalTable(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType n_rows, + LenType n_cols, + const OutType* mu_vec, + const OutType* sigma_vec, + OutType sigma, + cudaStream_t stream) +{ + detail::normalTable( + rng_state, ptr, n_rows, n_cols, mu_vec, sigma_vec, sigma, handle.get_stream()); +} + +/** + * @brief Fill an array with the given value + * + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output array + * @param[in] len the number of elements in the output + * @param[in] val value to be filled + */ +template +void fill(raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenType len, OutType val) +{ + detail::fill(rng_state, 
ptr, len, val, handle.get_stream()); +} + +/** + * @brief Generate bernoulli distributed boolean array + * + * @tparam Type data type in which to compute the probabilities + * @tparam OutType output data type + * @tparam LenType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output array + * @param[in] len the number of elements in the output + * @param[in] prob coin-toss probability for heads + */ +template +void bernoulli( + raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenType len, Type prob) +{ + detail::bernoulli(rng_state, ptr, len, prob, handle.get_stream()); +} + +/** + * @brief Generate bernoulli distributed array and applies scale + * + * @tparam OutType data type in which to compute the probabilities + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output array + * @param[in] len the number of elements in the output + * @param[in] prob coin-toss probability for heads + * @param[in] scale scaling factor + */ +template +void scaled_bernoulli(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType len, + OutType prob, + OutType scale) +{ + detail::scaled_bernoulli(rng_state, ptr, len, prob, scale, handle.get_stream()); +} + +/** + * @brief Generate Gumbel distributed random numbers + * + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr output array + * @param[in] len number of elements in the output array + * @param[in] mu mean value + * @param[in] beta scale value + * @note 
https://en.wikipedia.org/wiki/Gumbel_distribution + */ +template +void gumbel(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType len, + OutType mu, + OutType beta) +{ + detail::gumbel(rng_state, ptr, len, mu, beta, handle.get_stream()); +} + +/** + * @brief Generate lognormal distributed numbers + * + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr the output array + * @param[in] len the number of elements in the output + * @param[in] mu mean of the distribution + * @param[in] sigma std-dev of the distribution + */ +template +void lognormal(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType len, + OutType mu, + OutType sigma) +{ + detail::lognormal(rng_state, ptr, len, mu, sigma, handle.get_stream()); +} + +/** + * @brief Generate logistic distributed random numbers + * + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr output array + * @param[in] len number of elements in the output array + * @param[in] mu mean value + * @param[in] scale scale value + */ +template +void logistic(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType len, + OutType mu, + OutType scale) +{ + detail::logistic(rng_state, ptr, len, mu, scale, handle.get_stream()); +} + +/** + * @brief Generate exponentially distributed random numbers + * + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr output array + 
* @param[in] len number of elements in the output array + * @param[in] lambda the lambda + */ +template +void exponential( + raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenType len, OutType lambda) +{ + detail::exponential(rng_state, ptr, len, lambda, handle.get_stream()); +} + +/** + * @brief Generate rayleigh distributed random numbers + * + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr output array + * @param[in] len number of elements in the output array + * @param[in] sigma the sigma + */ +template +void rayleigh( + raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenType len, OutType sigma) +{ + detail::rayleigh(rng_state, ptr, len, sigma, handle.get_stream()); +} + +/** + * @brief Generate laplace distributed random numbers + * + * @tparam OutType data type of output random number + * @tparam LenType data type used to represent length of the arrays + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] ptr output array + * @param[in] len number of elements in the output array + * @param[in] mu the mean + * @param[in] scale the scale + */ +template +void laplace(raft::handle_t& const handle, + RngState& rng_state, + OutType* ptr, + LenType len, + OutType mu, + OutType scale) +{ + detail::laplace(rng_state, ptr, len, mu, scale, handle.get_stream()); +} + +/** + * @brief Sample the input array without replacement, optionally based on the + * input weight vector for each element in the array + * + * Implementation here is based on the `one-pass sampling` algo described here: + * https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf + * + * @note In the sampled array the elements which are picked will 
always appear + * in the increasing order of their weights as computed using the exponential + * distribution. So, if you're particular about the order (for eg. array + * permutations), then this might not be the right choice! + * + * @tparam DataT data type + * @tparam WeightsT weights type + * @tparam IdxT index type + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output sampled array (of length 'sampledLen') + * @param[out] outIdx indices of the sampled array (of length 'sampledLen'). Pass + * a nullptr if this is not required. + * @param[in] in input array to be sampled (of length 'len') + * @param[in] wts weights array (of length 'len'). Pass a nullptr if uniform + * sampling is desired + * @param[in] sampledLen output sampled array length + * @param[in] len input array length + */ +template +void sampleWithoutReplacement(raft::handle_t& const handle, + RngState& rng_state, + DataT* out, + IdxT* outIdx, + const DataT* in, + const WeightsT* wts, + IdxT sampledLen, + IdxT len) +{ + detail::sampleWithoutReplacement( + rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream()); +} + +/** + * @brief Generates the 'a' and 'b' parameters for a modulo affine + * transformation equation: `(ax + b) % n` + * + * @tparam IdxT integer type + * + * @param[in] rng_state random number generator state + * @param[in] n the modulo range + * @param[out] a slope parameter + * @param[out] b intercept parameter + */ +template +void affine_transform_params(RngState const& rng_state, IdxT n, IdxT& a, IdxT& b) +{ + detail::affine_transform_params(rng_state, n, a, b); +} /////////////////////////////////////////////////////////////////////////// // Everything below this point is deprecated and will be removed // diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index ffae76b8e6..12479c6fee 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -107,21 +107,21 @@ class 
RngTest : public ::testing::TestWithParam> { { RngState r(params.seed, params.gtype); switch (params.type) { - case RNG_Normal: normal(r, data.data(), params.len, params.start, params.end, stream); break; + case RNG_Normal: normal(handle, r, data.data(), params.len, params.start, params.end); break; case RNG_LogNormal: - lognormal(r, data.data(), params.len, params.start, params.end, stream); + lognormal(handle, r, data.data(), params.len, params.start, params.end); break; case RNG_Uniform: - uniform(r, data.data(), params.len, params.start, params.end, stream); + uniform(handle, r, data.data(), params.len, params.start, params.end); break; - case RNG_Gumbel: gumbel(r, data.data(), params.len, params.start, params.end, stream); break; + case RNG_Gumbel: gumbel(handle, r, data.data(), params.len, params.start, params.end); break; case RNG_Logistic: - logistic(r, data.data(), params.len, params.start, params.end, stream); + logistic(handle, r, data.data(), params.len, params.start, params.end); break; - case RNG_Exp: exponential(r, data.data(), params.len, params.start, stream); break; - case RNG_Rayleigh: rayleigh(r, data.data(), params.len, params.start, stream); break; + case RNG_Exp: exponential(handle, r, data.data(), params.len, params.start); break; + case RNG_Rayleigh: rayleigh(handle, r, data.data(), params.len, params.start); break; case RNG_Laplace: - laplace(r, data.data(), params.len, params.start, params.end, stream); + laplace(handle, r, data.data(), params.len, params.start, params.end); break; }; static const int threads = 128; @@ -292,7 +292,8 @@ TEST(Rng, MeanError) int num_experiments = 1024; int len = num_samples * num_experiments; - cudaStream_t stream; + raft::handle_t handle; + auto stream = handle.get_stream(); RAFT_CUDA_TRY(cudaStreamCreate(&stream)); rmm::device_uvector data(len, stream); @@ -301,7 +302,7 @@ TEST(Rng, MeanError) for (auto rtype : {GenPhilox, GenPC}) { RngState r(seed, rtype); - normal(r, data.data(), len, 3.3f, 0.23f, 
stream); + normal(handle, r, data.data(), len, 3.3f, 0.23f); // uniform(r, data, len, -1.0, 2.0); raft::stats::mean( mean_result.data(), data.data(), num_samples, num_experiments, false, false, stream); @@ -349,7 +350,7 @@ class ScaledBernoulliTest : public ::testing::Test { { RAFT_CUDA_TRY(cudaStreamCreate(&stream)); RngState r(42); - scaled_bernoulli(r, data.data(), len, T(0.5), T(scale), stream); + scaled_bernoulli(handle, r, data.data(), len, T(0.5), T(scale)); } void rangeCheck() @@ -382,7 +383,7 @@ class BernoulliTest : public ::testing::Test { void SetUp() override { RngState r(42); - bernoulli(r, data.data(), len, T(0.5), stream); + bernoulli(handle, r, data.data(), len, T(0.5)); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } @@ -448,7 +449,7 @@ class RngNormalTableTest : public ::testing::TestWithParam <<>>(stats.data(), data.data(), len); From 377e9cc0ceb068cadb73e09b266ca6e4676ace0f Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 3 May 2022 17:21:57 -0400 Subject: [PATCH 02/10] Fixing compile error --- cpp/include/raft/random/rng.cuh | 45 +++-- cpp/include/raft/spatial/knn/ball_cover.hpp | 177 +----------------- .../raft/spatial/knn/detail/ball_cover.cuh | 6 +- 3 files changed, 26 insertions(+), 202 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index caf551edca..3c30200923 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -14,9 +14,6 @@ * limitations under the License. 
*/ -#ifndef __RNG_H -#define __RNG_H - #pragma once #include "detail/rng_impl.cuh" @@ -24,11 +21,11 @@ #include "rng_state.hpp" #include -namespace raft { -namespace random { +namespace raft::random { /** * @brief Generate uniformly distributed numbers in the given range + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -39,7 +36,7 @@ namespace random { * @param[in] end end of the range */ template -void uniform(raft::handle_t& const handle, +void uniform(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, @@ -51,6 +48,7 @@ void uniform(raft::handle_t& const handle, /** * @brief Generate uniformly distributed integers in the given range + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -61,7 +59,7 @@ void uniform(raft::handle_t& const handle, * @param[in] end end of the range */ template -void uniformInt(raft::handle_t& const handle, +void uniformInt(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, @@ -73,6 +71,7 @@ void uniformInt(raft::handle_t& const handle, /** * @brief Generate normal distributed numbers + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -83,7 +82,7 @@ void uniformInt(raft::handle_t& const handle, * @param[in] sigma std-dev of the distribution */ template -void normal(raft::handle_t& const handle, +void normal(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, @@ -95,6 +94,7 @@ void normal(raft::handle_t& const handle, /** * @brief Generate normal distributed integers + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length 
of the arrays * @param[in] handle raft handle for resource management @@ -105,7 +105,7 @@ void normal(raft::handle_t& const handle, * @param[in] sigma std-dev of the distribution */ template -void normalInt(raft::handle_t& const handle, +void normalInt(const raft::handle_t& handle, RngState& rng_state, IntType* ptr, LenType len, @@ -137,7 +137,7 @@ void normalInt(raft::handle_t& const handle, * @param[in] sigma scalar sigma to be used if 'sigma_vec' is nullptr */ template -void normalTable(raft::handle_t& const handle, +void normalTable(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType n_rows, @@ -163,7 +163,7 @@ void normalTable(raft::handle_t& const handle, * @param[in] val value to be filled */ template -void fill(raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenType len, OutType val) +void fill(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, OutType val) { detail::fill(rng_state, ptr, len, val, handle.get_stream()); } @@ -183,7 +183,7 @@ void fill(raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenTy */ template void bernoulli( - raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenType len, Type prob) + const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, Type prob) { detail::bernoulli(rng_state, ptr, len, prob, handle.get_stream()); } @@ -201,7 +201,7 @@ void bernoulli( * @param[in] scale scaling factor */ template -void scaled_bernoulli(raft::handle_t& const handle, +void scaled_bernoulli(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, @@ -225,7 +225,7 @@ void scaled_bernoulli(raft::handle_t& const handle, * @note https://en.wikipedia.org/wiki/Gumbel_distribution */ template -void gumbel(raft::handle_t& const handle, +void gumbel(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, @@ -248,7 +248,7 @@ void gumbel(raft::handle_t& const handle, * @param[in] sigma std-dev of the 
distribution */ template -void lognormal(raft::handle_t& const handle, +void lognormal(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, @@ -271,7 +271,7 @@ void lognormal(raft::handle_t& const handle, * @param[in] scale scale value */ template -void logistic(raft::handle_t& const handle, +void logistic(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, @@ -294,7 +294,7 @@ void logistic(raft::handle_t& const handle, */ template void exponential( - raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenType len, OutType lambda) + const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, OutType lambda) { detail::exponential(rng_state, ptr, len, lambda, handle.get_stream()); } @@ -312,7 +312,7 @@ void exponential( */ template void rayleigh( - raft::handle_t& const handle, RngState& rng_state, OutType* ptr, LenType len, OutType sigma) + const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, OutType sigma) { detail::rayleigh(rng_state, ptr, len, sigma, handle.get_stream()); } @@ -330,7 +330,7 @@ void rayleigh( * @param[in] scale the scale */ template -void laplace(raft::handle_t& const handle, +void laplace(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenType len, @@ -367,7 +367,7 @@ void laplace(raft::handle_t& const handle, * @param[in] len input array length */ template -void sampleWithoutReplacement(raft::handle_t& const handle, +void sampleWithoutReplacement(const raft::handle_t& handle, RngState& rng_state, DataT* out, IdxT* outIdx, @@ -702,7 +702,4 @@ class DEPR Rng : public detail::RngImpl { #undef DEPR -}; // end namespace random -}; // end namespace raft - -#endif +}; // end namespace raft::random diff --git a/cpp/include/raft/spatial/knn/ball_cover.hpp b/cpp/include/raft/spatial/knn/ball_cover.hpp index a7c483493e..fd09426594 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.hpp +++ 
b/cpp/include/raft/spatial/knn/ball_cover.hpp @@ -14,183 +14,10 @@ * limitations under the License. */ /** - * This file is deprecated and will be removed in release 22.06. + * This file is deprecated and will be removed in release 22.08. * Please use the cuh version instead. */ -#ifndef __BALL_COVER_H -#define __BALL_COVER_H - #pragma once -#include - -#include "ball_cover_common.h" -#include "detail/ball_cover.cuh" -#include "detail/ball_cover/common.cuh" -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template -void rbc_build_index(const raft::handle_t& handle, - BallCoverIndex& index) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - detail::rbc_build_index(handle, index, detail::HaversineFunc()); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - detail::rbc_build_index(handle, index, detail::EuclideanFunc()); - } else { - RAFT_FAIL("Metric not support"); - } - - index.set_index_trained(); -} - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * performs an all neighbors knn, which can reuse memory when - * the index and query are the same array. This function will - * build the index and assumes rbc_build_index() has not already - * been called. 
- * @tparam value_idx knn index type - * @tparam value_t knn distance type - * @tparam value_int type for integers, such as number of rows/cols - * @param handle raft handle for resource management - * @param index ball cover index which has not yet been built - * @param k number of nearest neighbors to find - * @param perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - */ -template -void rbc_all_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, - value_int k, - value_idx* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - detail::rbc_all_knn_query(handle, - index, - k, - inds, - dists, - detail::HaversineFunc(), - perform_post_filtering, - weight); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - detail::rbc_all_knn_query(handle, - index, - k, - inds, - dists, - detail::EuclideanFunc(), - perform_post_filtering, - weight); - } else { - RAFT_FAIL("Metric not supported"); - } - - index.set_index_trained(); -} - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * 
number of distance computations from O(n^2) to O(sqrt(n)). This - * function does not build the index and assumes rbc_build_index() has - * already been called. Use this function when the index and - * query arrays are different, otherwise use rbc_all_knn_query(). - * @tparam value_idx index type - * @tparam value_t distances type - * @tparam value_int integer type for size info - * @param handle raft handle for resource management - * @param index ball cover index which has not yet been built - * @param k number of nearest neighbors to find - * @param query the - * @param perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. 
- * @param[in] n_query_pts number of query points - */ -template -void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, - value_int k, - const value_t* query, - value_int n_query_pts, - value_idx* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - detail::rbc_knn_query(handle, - index, - k, - query, - n_query_pts, - inds, - dists, - detail::HaversineFunc(), - perform_post_filtering, - weight); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - detail::rbc_knn_query(handle, - index, - k, - query, - n_query_pts, - inds, - dists, - detail::EuclideanFunc(), - perform_post_filtering, - weight); - } else { - RAFT_FAIL("Metric not supported"); - } -} - -// TODO: implement functions for: -// 4. rbc_eps_neigh() - given a populated index, perform query against different query array -// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data - -} // namespace knn -} // namespace spatial -} // namespace raft - -#endif \ No newline at end of file +#include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 4f8c66e05d..6200408539 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -84,14 +84,14 @@ void sample_landmarks(const raft::handle_t& handle, * 1. 
Randomly sample sqrt(n) points from X */ raft::random::RngState rng_state(12345); - raft::random::sampleWithoutReplacement(rng_state, + raft::random::sampleWithoutReplacement(handle, + rng_state, R_indices.data(), R_1nn_cols2.data(), index.get_R_1nn_cols(), R_1nn_ones.data(), (value_idx)index.n_landmarks, - (value_idx)index.m, - handle.get_stream()); + (value_idx)index.m); raft::matrix::copyRows(index.get_X(), index.m, From 99a7ae7d47eaca71f8d19aea436f7861467133fc Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 3 May 2022 18:34:09 -0400 Subject: [PATCH 03/10] Updating tests --- cpp/include/raft/linalg/detail/rsvd.cuh | 2 +- cpp/include/raft/random/detail/make_blobs.cuh | 4 +-- cpp/include/raft/random/rng.cuh | 3 +- cpp/test/distance/dist_adj.cu | 4 +-- cpp/test/distance/distance_base.cuh | 16 +++++----- cpp/test/distance/fused_l2_nn.cu | 4 +-- cpp/test/linalg/add.cu | 4 +-- cpp/test/linalg/binary_op.cu | 4 +-- cpp/test/linalg/coalesced_reduction.cu | 2 +- cpp/test/linalg/divide.cu | 2 +- cpp/test/linalg/eig.cu | 2 +- cpp/test/linalg/eltwise.cu | 6 ++-- cpp/test/linalg/gemm_layout.cu | 4 +-- cpp/test/linalg/gemv.cu | 4 +-- cpp/test/linalg/map.cu | 6 ++-- cpp/test/linalg/map_then_reduce.cu | 4 +-- cpp/test/linalg/matrix_vector_op.cu | 6 ++-- cpp/test/linalg/multiply.cu | 2 +- cpp/test/linalg/norm.cu | 4 +-- cpp/test/linalg/power.cu | 4 +-- cpp/test/linalg/reduce.cu | 2 +- cpp/test/linalg/reduce_cols_by_key.cu | 15 ++++----- cpp/test/linalg/reduce_rows_by_key.cu | 6 ++-- cpp/test/linalg/rsvd.cu | 2 +- cpp/test/linalg/sqrt.cu | 2 +- cpp/test/linalg/strided_reduction.cu | 3 +- cpp/test/linalg/subtract.cu | 4 +-- cpp/test/linalg/ternary_op.cu | 10 +++--- cpp/test/linalg/unary_op.cu | 2 +- cpp/test/matrix/linewise_op.cu | 2 +- cpp/test/matrix/math.cu | 6 ++-- cpp/test/matrix/matrix.cu | 2 +- cpp/test/random/make_blobs.cu | 2 +- cpp/test/random/permute.cu | 7 +++-- cpp/test/random/rng.cu | 4 +-- cpp/test/random/rng_int.cu | 2 +- 
cpp/test/random/sample_without_replacement.cu | 6 ++-- cpp/test/sparse/filter.cu | 9 +++--- cpp/test/sparse/sort.cu | 8 ++--- cpp/test/spatial/fused_l2_knn.cu | 4 +-- cpp/test/spatial/selection.cu | 5 +-- cpp/test/stats/cov.cu | 2 +- cpp/test/stats/dispersion.cu | 9 +++--- cpp/test/stats/histogram.cu | 7 +++-- cpp/test/stats/mean.cu | 2 +- cpp/test/stats/mean_center.cu | 2 +- cpp/test/stats/meanvar.cu | 2 +- cpp/test/stats/minmax.cu | 12 +++---- cpp/test/stats/stddev.cu | 2 +- cpp/test/stats/weighted_mean.cu | 31 +++++++++---------- 50 files changed, 126 insertions(+), 133 deletions(-) diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 847209a09b..5487aead19 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -93,7 +93,7 @@ void rsvdFixedRank(const raft::handle_t& handle, // build random matrix rmm::device_uvector RN(n * l, stream); raft::random::RngState state{484}; - raft::random::normal(state, RN.data(), n * l, math_t(0.0), alpha, stream); + raft::random::normal(handle, state, RN.data(), n * l, math_t(0.0), alpha); // multiply to get matrix of random samples Y rmm::device_uvector Y(m * l, stream); diff --git a/cpp/include/raft/random/detail/make_blobs.cuh b/cpp/include/raft/random/detail/make_blobs.cuh index 4315aec1b7..f214abce58 100644 --- a/cpp/include/raft/random/detail/make_blobs.cuh +++ b/cpp/include/raft/random/detail/make_blobs.cuh @@ -40,7 +40,7 @@ void generate_labels(IdxT* labels, cudaStream_t stream) { IdxT a, b; - affine_transform_params(r, n_clusters, a, b); + raft::random::affine_transform_params(r, n_clusters, a, b); auto op = [=] __device__(IdxT * ptr, IdxT idx) { if (shuffle) { idx = IdxT((a * int64_t(idx)) + b); } idx %= n_clusters; @@ -230,7 +230,7 @@ void make_blobs_caller(DataT* out, const DataT* _centers; if (centers == nullptr) { rand_centers.resize(n_clusters * n_cols, stream); - raft::random::uniform( + detail::uniform( r, 
rand_centers.data(), n_clusters * n_cols, center_box_min, center_box_max, stream); _centers = rand_centers.data(); } else { diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 3c30200923..faa18cbcd9 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -144,8 +144,7 @@ void normalTable(const raft::handle_t& handle, LenType n_cols, const OutType* mu_vec, const OutType* sigma_vec, - OutType sigma, - cudaStream_t stream) + OutType sigma) { detail::normalTable( rng_state, ptr, n_rows, n_cols, mu_vec, sigma_vec, sigma, handle.get_stream()); diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index 36db82af3d..16c6e11719 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -102,8 +102,8 @@ class DistanceAdjTest : public ::testing::TestWithParam x(m * k, stream); rmm::device_uvector y(n * k, stream); - uniform(r, x.data(), m * k, DataType(-1.0), DataType(1.0), stream); - uniform(r, y.data(), n * k, DataType(-1.0), DataType(1.0), stream); + uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); + uniform(handle, r, y.data(), n * k, DataType(-1.0), DataType(1.0)); DataType threshold = params.eps; diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 970abdc076..07643bc4ea 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -436,17 +436,17 @@ class DistanceTest : public ::testing::TestWithParam> { distanceType == raft::distance::DistanceType::JensenShannon || distanceType == raft::distance::DistanceType::KLDivergence) { // Hellinger works only on positive numbers - uniform(r, x.data(), m * k, DataType(0.0), DataType(1.0), stream); - uniform(r, y.data(), n * k, DataType(0.0), DataType(1.0), stream); + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); } else if (distanceType == 
raft::distance::DistanceType::RusselRaoExpanded) { - uniform(r, x.data(), m * k, DataType(0.0), DataType(1.0), stream); - uniform(r, y.data(), n * k, DataType(0.0), DataType(1.0), stream); + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); // Russel rao works on boolean values. - bernoulli(r, x.data(), m * k, 0.5f, stream); - bernoulli(r, y.data(), n * k, 0.5f, stream); + bernoulli(handle, r, x.data(), m * k, 0.5f); + bernoulli(handle, r, y.data(), n * k, 0.5f); } else { - uniform(r, x.data(), m * k, DataType(-1.0), DataType(1.0), stream); - uniform(r, y.data(), n * k, DataType(-1.0), DataType(1.0), stream); + uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); + uniform(handle, r, y.data(), n * k, DataType(-1.0), DataType(1.0)); } naiveDistance( dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg, stream); diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 5b25747c1a..192f0c9a74 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -127,8 +127,8 @@ class FusedL2NNTest : public ::testing::TestWithParam> { int m = params.m; int n = params.n; int k = params.k; - uniform(r, x.data(), m * k, DataT(-1.0), DataT(1.0), stream); - uniform(r, y.data(), n * k, DataT(-1.0), DataT(1.0), stream); + uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); generateGoldenResult(); raft::linalg::rowNorm(xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream); raft::linalg::rowNorm(yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream); diff --git a/cpp/test/linalg/add.cu b/cpp/test/linalg/add.cu index fcebe2b8e0..ba9dac5ac2 100644 --- a/cpp/test/linalg/add.cu +++ b/cpp/test/linalg/add.cu @@ -43,8 +43,8 @@ class AddTest : public ::testing::TestWithParam> { params = 
::testing::TestWithParam>::GetParam(); raft::random::RngState r{params.seed}; int len = params.len; - uniform(r, in1.data(), len, InT(-1.0), InT(1.0), stream); - uniform(r, in2.data(), len, InT(-1.0), InT(1.0), stream); + uniform(handle, r, in1.data(), len, InT(-1.0), InT(1.0)); + uniform(handle, r, in2.data(), len, InT(-1.0), InT(1.0)); naiveAddElem(out_ref.data(), in1.data(), in2.data(), len, stream); add(out.data(), in1.data(), in2.data(), len, stream); handle.sync_stream(stream); diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index e265ea8b11..cd4340f5cd 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -54,8 +54,8 @@ class BinaryOpTest : public ::testing::TestWithParam> { // test code for comparing two methods len = params.n * params.n; - uniform(r, cov_matrix_large.data(), len, T(-1.0), T(1.0), stream); + uniform(handle, r, cov_matrix_large.data(), len, T(-1.0), T(1.0)); eigDC(handle, cov_matrix_large.data(), diff --git a/cpp/test/linalg/eltwise.cu b/cpp/test/linalg/eltwise.cu index daf1418544..e2bc80eefe 100644 --- a/cpp/test/linalg/eltwise.cu +++ b/cpp/test/linalg/eltwise.cu @@ -73,7 +73,7 @@ class ScalarMultiplyTest : public ::testing::TestWithParam> { params = ::testing::TestWithParam>::GetParam(); raft::random::RngState r(params.seed); int len = params.len; - uniform(r, in1, len, T(-1.0), T(1.0), stream); - uniform(r, in2, len, T(-1.0), T(1.0), stream); + uniform(handle, r, in1, len, T(-1.0), T(1.0)); + uniform(handle, r, in2, len, T(-1.0), T(1.0)); naiveAdd(out_ref, in1, in2, len, stream); eltwiseAdd(out, in1, in2, len, stream); handle.sync_stream(stream); diff --git a/cpp/test/linalg/gemm_layout.cu b/cpp/test/linalg/gemm_layout.cu index e87b6d2570..5881d25b60 100644 --- a/cpp/test/linalg/gemm_layout.cu +++ b/cpp/test/linalg/gemm_layout.cu @@ -85,8 +85,8 @@ class GemmLayoutTest : public ::testing::TestWithParam> { RAFT_CUDA_TRY(cudaMalloc(&refZ, zElems * sizeof(T))); RAFT_CUDA_TRY(cudaMalloc(&Z, 
zElems * sizeof(T))); - uniform(r, X, xElems, T(-10.0), T(10.0), stream); - uniform(r, Y, yElems, T(-10.0), T(10.0), stream); + uniform(handle, r, X, xElems, T(-10.0), T(10.0)); + uniform(handle, r, Y, yElems, T(-10.0), T(10.0)); dim3 blocks(raft::ceildiv(params.M, 128), raft::ceildiv(params.N, 4), 1); dim3 threads(128, 4, 1); diff --git a/cpp/test/linalg/gemv.cu b/cpp/test/linalg/gemv.cu index e222dd00eb..97f5f6de94 100644 --- a/cpp/test/linalg/gemv.cu +++ b/cpp/test/linalg/gemv.cu @@ -100,8 +100,8 @@ class GemvTest : public ::testing::TestWithParam> { refy.resize(yElems, stream); y.resize(yElems, stream); - uniform(r, x.data(), xElems, T(-10.0), T(10.0), stream); - uniform(r, A.data(), aElems, T(-10.0), T(10.0), stream); + uniform(handle, r, x.data(), xElems, T(-10.0), T(10.0)); + uniform(handle, r, A.data(), aElems, T(-10.0), T(10.0)); dim3 blocks(raft::ceildiv(yElems, 256), 1, 1); dim3 threads(256, 1, 1); diff --git a/cpp/test/linalg/map.cu b/cpp/test/linalg/map.cu index 217378d224..bcaacb3c8f 100644 --- a/cpp/test/linalg/map.cu +++ b/cpp/test/linalg/map.cu @@ -87,9 +87,9 @@ class MapTest : public ::testing::TestWithParam> { { raft::random::RngState r(params.seed); auto len = params.len; - uniform(r, in.data(), len, InType(-1.0), InType(1.0), stream); + uniform(handle, r, in.data(), len, InType(-1.0), InType(1.0)); mapReduceLaunch(out_ref.data(), out.data(), in.data(), len, stream); handle.sync_stream(stream); } @@ -140,7 +140,7 @@ class MapGenericReduceTest : public ::testing::Test { void initInput(InType* input, int n, cudaStream_t stream) { raft::random::RngState r(137); - uniform(r, input, n, InType(2), InType(3), handle.get_stream()); + uniform(handle, r, input, n, InType(2), InType(3)); InType val = 1; raft::update_device(input + 42, &val, 1, handle.get_stream()); val = 5; diff --git a/cpp/test/linalg/matrix_vector_op.cu b/cpp/test/linalg/matrix_vector_op.cu index 8f5e3c7fa1..b01b3a1ca1 100644 --- a/cpp/test/linalg/matrix_vector_op.cu +++ 
b/cpp/test/linalg/matrix_vector_op.cu @@ -99,9 +99,9 @@ class MatVecOpTest : public ::testing::TestWithParam> IdxType N = params.rows, D = params.cols; IdxType len = N * D; IdxType vecLen = params.bcastAlongRows ? D : N; - uniform(r, in.data(), len, (T)-1.0, (T)1.0, stream); - uniform(r, vec1.data(), vecLen, (T)-1.0, (T)1.0, stream); - uniform(r, vec2.data(), vecLen, (T)-1.0, (T)1.0, stream); + uniform(handle, r, in.data(), len, (T)-1.0, (T)1.0); + uniform(handle, r, vec1.data(), vecLen, (T)-1.0, (T)1.0); + uniform(handle, r, vec2.data(), vecLen, (T)-1.0, (T)1.0); if (params.useTwoVectors) { naiveMatVec(out_ref.data(), in.data(), diff --git a/cpp/test/linalg/multiply.cu b/cpp/test/linalg/multiply.cu index 2d5e191199..e91201aa12 100644 --- a/cpp/test/linalg/multiply.cu +++ b/cpp/test/linalg/multiply.cu @@ -42,7 +42,7 @@ class MultiplyTest : public ::testing::TestWithParam> { params = ::testing::TestWithParam>::GetParam(); raft::random::RngState r(params.seed); int len = params.len; - uniform(r, in.data(), len, T(-1.0), T(1.0), stream); + uniform(handle, r, in.data(), len, T(-1.0), T(1.0)); naiveScale(out_ref.data(), in.data(), params.scalar, len, stream); multiplyScalar(out.data(), in.data(), params.scalar, len, stream); handle.sync_stream(stream); diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index 8c54f7519a..83ded7d052 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -86,7 +86,7 @@ class RowNormTest : public ::testing::TestWithParam> { { raft::random::RngState r(params.seed); int rows = params.rows, cols = params.cols, len = rows * cols; - uniform(r, data.data(), len, T(-1.0), T(1.0), stream); + uniform(handle, r, data.data(), len, T(-1.0), T(1.0)); naiveRowNorm(dots_exp.data(), data.data(), cols, rows, params.type, params.do_sqrt, stream); if (params.do_sqrt) { auto fin_op = [] __device__(T in) { return raft::mySqrt(in); }; @@ -149,7 +149,7 @@ class ColNormTest : public ::testing::TestWithParam> { { raft::random::RngState 
r(params.seed); int rows = params.rows, cols = params.cols, len = rows * cols; - uniform(r, data.data(), len, T(-1.0), T(1.0), stream); + uniform(handle, r, data.data(), len, T(-1.0), T(1.0)); naiveColNorm(dots_exp.data(), data.data(), cols, rows, params.type, params.do_sqrt, stream); if (params.do_sqrt) { diff --git a/cpp/test/linalg/power.cu b/cpp/test/linalg/power.cu index 69ff7fb8c1..7c93b52d59 100644 --- a/cpp/test/linalg/power.cu +++ b/cpp/test/linalg/power.cu @@ -91,8 +91,8 @@ class PowerTest : public ::testing::TestWithParam> { in2.resize(len, stream); out_ref.resize(len, stream); out.resize(len, stream); - uniform(r, in1.data(), len, T(1.0), T(2.0), stream); - uniform(r, in2.data(), len, T(1.0), T(2.0), stream); + uniform(handle, r, in1.data(), len, T(1.0), T(2.0)); + uniform(handle, r, in2.data(), len, T(1.0), T(2.0)); naivePowerElem(out_ref.data(), in1.data(), in2.data(), len, stream); naivePowerScalar(out_ref.data(), out_ref.data(), T(2), len, stream); diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index e7d8275783..19d6130df9 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -83,7 +83,7 @@ class ReduceTest : public ::testing::TestWithParam int rows = params.rows, cols = params.cols; int len = rows * cols; outlen = params.alongRows ? 
rows : cols; - uniform(r, data.data(), len, InType(-1.0), InType(1.0), stream); + uniform(handle, r, data.data(), len, InType(-1.0), InType(1.0)); naiveReduction( dots_exp.data(), data.data(), cols, rows, params.rowMajor, params.alongRows, stream); diff --git a/cpp/test/linalg/reduce_cols_by_key.cu b/cpp/test/linalg/reduce_cols_by_key.cu index 9371e4158a..067e3a8b0e 100644 --- a/cpp/test/linalg/reduce_cols_by_key.cu +++ b/cpp/test/linalg/reduce_cols_by_key.cu @@ -72,22 +72,23 @@ class ReduceColsTest : public ::testing::TestWithParam> { { params = ::testing::TestWithParam>::GetParam(); raft::random::RngState r(params.seed); - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - auto nrows = params.rows; - auto ncols = params.cols; - auto nkeys = params.nkeys; + raft::handle_t handle; + auto stream = handle.get_stream(); + auto nrows = params.rows; + auto ncols = params.cols; + auto nkeys = params.nkeys; in.resize(nrows * ncols, stream); keys.resize(ncols, stream); out_ref.resize(nrows * nkeys, stream); out.resize(nrows * nkeys, stream); - uniform(r, in.data(), nrows * ncols, T(-1.0), T(1.0), stream); - uniformInt(r, keys.data(), ncols, 0u, params.nkeys, stream); + uniform(handle, r, in.data(), nrows * ncols, T(-1.0), T(1.0)); + uniformInt(handle, r, keys.data(), ncols, 0u, params.nkeys); naiveReduceColsByKey(in.data(), keys.data(), out_ref.data(), nrows, ncols, nkeys, stream); reduce_cols_by_key(in.data(), keys.data(), out.data(), nrows, ncols, nkeys, stream); raft::interruptible::synchronize(stream); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } + void TearDown() override {} protected: cudaStream_t stream = 0; diff --git a/cpp/test/linalg/reduce_rows_by_key.cu b/cpp/test/linalg/reduce_rows_by_key.cu index 499e2d96c5..5ebf6c5daa 100644 --- a/cpp/test/linalg/reduce_rows_by_key.cu +++ b/cpp/test/linalg/reduce_rows_by_key.cu @@ -106,14 +106,14 @@ class ReduceRowTest : public ::testing::TestWithParam> { int nobs = params.nobs; uint32_t cols = 
params.cols; uint32_t nkeys = params.nkeys; - uniform(r, in.data(), nobs * cols, T(0.0), T(2.0 / nobs), stream); - uniformInt(r_int, keys.data(), nobs, (uint32_t)0, nkeys, stream); + uniform(handle, r, in.data(), nobs * cols, T(0.0), T(2.0 / nobs)); + uniformInt(handle, r_int, keys.data(), nobs, (uint32_t)0, nkeys); rmm::device_uvector weight(0, stream); if (params.weighted) { weight.resize(nobs, stream); raft::random::RngState r(params.seed, raft::random::GeneratorType::GenPhilox); - uniform(r, weight.data(), nobs, T(1), params.max_weight, stream); + uniform(handle, r, weight.data(), nobs, T(1), params.max_weight); } naiveReduceRowsByKey(in.data(), diff --git a/cpp/test/linalg/rsvd.cu b/cpp/test/linalg/rsvd.cu index 2146526de5..568ab504a2 100644 --- a/cpp/test/linalg/rsvd.cu +++ b/cpp/test/linalg/rsvd.cu @@ -101,7 +101,7 @@ class RsvdTest : public ::testing::TestWithParam> { int n_redundant = n - n_informative; // Redundant cols int len_redundant = m * n_redundant; - normal(r, A.data(), len_informative, mu, sigma, stream); + normal(handle, r, A.data(), len_informative, mu, sigma); RAFT_CUDA_TRY(cudaMemcpyAsync(A.data() + len_informative, A.data(), len_redundant * sizeof(T), diff --git a/cpp/test/linalg/sqrt.cu b/cpp/test/linalg/sqrt.cu index aa681e07f3..b9fff65a80 100644 --- a/cpp/test/linalg/sqrt.cu +++ b/cpp/test/linalg/sqrt.cu @@ -69,7 +69,7 @@ class SqrtTest : public ::testing::TestWithParam> { in1.resize(len, stream); out_ref.resize(len, stream); out.resize(len, stream); - uniform(r, in1.data(), len, T(1.0), T(2.0), stream); + uniform(handle, r, in1.data(), len, T(1.0), T(2.0)); naiveSqrtElem(out_ref.data(), in1.data(), len); diff --git a/cpp/test/linalg/strided_reduction.cu b/cpp/test/linalg/strided_reduction.cu index 1b72b50739..a1b6f9de0d 100644 --- a/cpp/test/linalg/strided_reduction.cu +++ b/cpp/test/linalg/strided_reduction.cu @@ -56,8 +56,7 @@ class stridedReductionTest : public ::testing::TestWithParam> { { raft::random::RngState r(params.seed); int 
len = params.len; - uniform(r, in1.data(), len, T(-1.0), T(1.0), stream); - uniform(r, in2.data(), len, T(-1.0), T(1.0), stream); + uniform(handle, r, in1.data(), len, T(-1.0), T(1.0)); + uniform(handle, r, in2.data(), len, T(-1.0), T(1.0)); naiveSubtractElem(out_ref.data(), in1.data(), in2.data(), len, stream); naiveSubtractScalar(out_ref.data(), out_ref.data(), T(1), len, stream); diff --git a/cpp/test/linalg/ternary_op.cu b/cpp/test/linalg/ternary_op.cu index a453e9effe..a34274a412 100644 --- a/cpp/test/linalg/ternary_op.cu +++ b/cpp/test/linalg/ternary_op.cu @@ -57,11 +57,11 @@ class ternaryOpTest : public ::testing::TestWithParam> { rmm::device_uvector in2(len, stream); rmm::device_uvector in3(len, stream); - fill(rng, out_add_ref.data(), len, T(6.0), stream); - fill(rng, out_mul_ref.data(), len, T(6.0), stream); - fill(rng, in1.data(), len, T(1.0), stream); - fill(rng, in2.data(), len, T(2.0), stream); - fill(rng, in3.data(), len, T(3.0), stream); + fill(handle, rng, out_add_ref.data(), len, T(6.0)); + fill(handle, rng, out_mul_ref.data(), len, T(6.0)); + fill(handle, rng, in1.data(), len, T(1.0)); + fill(handle, rng, in2.data(), len, T(2.0)); + fill(handle, rng, in3.data(), len, T(3.0)); auto add = [] __device__(T a, T b, T c) { return a + b + c; }; auto mul = [] __device__(T a, T b, T c) { return a * b * c; }; diff --git a/cpp/test/linalg/unary_op.cu b/cpp/test/linalg/unary_op.cu index 74a49fd58f..8d4725b72f 100644 --- a/cpp/test/linalg/unary_op.cu +++ b/cpp/test/linalg/unary_op.cu @@ -58,7 +58,7 @@ class UnaryOpTest : public ::testing::TestWithParam blob(workSizeElems, stream); - uniform(r, blob.data(), workSizeElems, T(-1.0), T(1.0), stream); + uniform(handle, r, blob.data(), workSizeElems, T(-1.0), T(1.0)); return blob; } diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index 18e7d0efea..30a6ed7083 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -141,10 +141,10 @@ class MathTest : public ::testing::TestWithParam> { T 
out_ratio_ref_h[4] = {0.125, 0.25, 0.25, 0.375}; update_device(out_ratio_ref.data(), out_ratio_ref_h, 4, stream); - uniform(r, in_power.data(), len, T(-1.0), T(1.0), stream); - uniform(r, in_sqrt.data(), len, T(0.0), T(1.0), stream); + uniform(handle, r, in_power.data(), len, T(-1.0), T(1.0)); + uniform(handle, r, in_sqrt.data(), len, T(0.0), T(1.0)); // uniform(r, in_ratio, len, T(0.0), T(1.0)); - uniform(r, in_sign_flip.data(), len, T(-100.0), T(100.0), stream); + uniform(handle, r, in_sign_flip.data(), len, T(-100.0), T(100.0)); naivePower(in_power.data(), out_power_ref.data(), len, stream); power(in_power.data(), len, stream); diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 77144213f3..1b6ac57fc4 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -55,7 +55,7 @@ class MatrixTest : public ::testing::TestWithParam> { { raft::random::RngState r(params.seed); int len = params.n_row * params.n_col; - uniform(r, in1.data(), len, T(-1.0), T(1.0), stream); + uniform(handle, r, in1.data(), len, T(-1.0), T(1.0)); copy(in1.data(), in2.data(), params.n_row, params.n_col, stream); // copy(in1, in1_revr, params.n_row, params.n_col); diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index 7390f9dbb3..8540e755a4 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -103,7 +103,7 @@ class MakeBlobsTest : public ::testing::TestWithParam> { RAFT_CUDA_TRY(cudaMemsetAsync(lens.data(), 0, lens.extent(0) * sizeof(int), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(mean_var.data(), 0, mean_var.size() * sizeof(T), stream)); - uniform(r, mu_vec.data(), params.cols * params.n_clusters, T(-10.0), T(10.0), stream); + uniform(handle, r, mu_vec.data(), params.cols * params.n_clusters, T(-10.0), T(10.0)); make_blobs(handle, data.view(), diff --git a/cpp/test/random/permute.cu b/cpp/test/random/permute.cu index 9a407367c7..84dc5970c4 100644 --- a/cpp/test/random/permute.cu +++ 
b/cpp/test/random/permute.cu @@ -45,7 +45,8 @@ class PermTest : public ::testing::TestWithParam> { void SetUp() override { - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + raft::handle_t h; + stream = h.get_stream(); params = ::testing::TestWithParam>::GetParam(); // forcefully set needPerms, since we need it for unit-testing! if (params.needShuffle) { params.needPerms = true; } @@ -64,10 +65,10 @@ class PermTest : public ::testing::TestWithParam> { out.resize(len, stream); in_ptr = in.data(); out_ptr = out.data(); - uniform(r, in_ptr, len, T(-1.0), T(1.0), stream); + uniform(h, r, in_ptr, len, T(-1.0), T(1.0)); } permute(outPerms_ptr, out_ptr, in_ptr, D, N, params.rowMajor, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + h.sync_stream(); } void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index 12479c6fee..58107293ee 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -294,7 +294,6 @@ TEST(Rng, MeanError) raft::handle_t handle; auto stream = handle.get_stream(); - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); rmm::device_uvector data(len, stream); rmm::device_uvector mean_result(num_experiments, stream); @@ -335,7 +334,6 @@ TEST(Rng, MeanError) ASSERT_TRUE((diff_expected_vs_measured_mean_error / d_std_of_mean_analytical < 0.5)) << "Failed with seed: " << seed << "\nrtype: " << rtype; } - RAFT_CUDA_TRY(cudaStreamDestroy(stream)); // std::cout << "mean_res:" << h_mean_result << "\n"; } @@ -446,7 +444,7 @@ class RngNormalTableTest : public ::testing::TestWithParam> { switch (params.type) { case RNG_Uniform: - uniformInt(r, data.data(), params.len, params.start, params.end, stream); + uniformInt(handle, r, data.data(), params.len, params.start, params.end); break; }; static const int threads = 128; diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index b558e9e879..63f0b20df4 100644 --- 
a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -63,13 +63,13 @@ class SWoRTest : public ::testing::TestWithParam> { { RngState r(params.seed, params.gtype); h_outIdx.resize(params.sampledLen); - uniform(r, in.data(), params.len, T(-1.0), T(1.0), stream); - uniform(r, wts.data(), params.len, T(1.0), T(2.0), stream); + uniform(handle, r, in.data(), params.len, T(-1.0), T(1.0)); + uniform(handle, r, wts.data(), params.len, T(1.0), T(2.0)); if (params.largeWeightIndex >= 0) { update_device(wts.data() + params.largeWeightIndex, &params.largeWeight, 1, stream); } sampleWithoutReplacement( - r, out.data(), outIdx.data(), in.data(), wts.data(), params.sampledLen, params.len, stream); + handle, r, out.data(), outIdx.data(), in.data(), wts.data(), params.sampledLen, params.len); update_host(&(h_outIdx[0]), outIdx.data(), params.sampledLen, stream); handle.sync_stream(stream); } diff --git a/cpp/test/sparse/filter.cu b/cpp/test/sparse/filter.cu index 11a63e67a1..c22fe09134 100644 --- a/cpp/test/sparse/filter.cu +++ b/cpp/test/sparse/filter.cu @@ -50,16 +50,16 @@ const std::vector> inputsf = {{5, 10, 5, 1234ULL}}; typedef SparseFilterTests COORemoveZeros; TEST_P(COORemoveZeros, Result) { - cudaStream_t stream; - cudaStreamCreate(&stream); - params = ::testing::TestWithParam>::GetParam(); + raft::handle_t h; + auto stream = h.get_stream(); + params = ::testing::TestWithParam>::GetParam(); float* in_h_vals = new float[params.nnz]; COO in(stream, params.nnz, 5, 5); raft::random::RngState r(params.seed); - uniform(r, in.vals(), params.nnz, float(-1.0), float(1.0), stream); + uniform(h, r, in.vals(), params.nnz, float(-1.0), float(1.0)); raft::update_host(in_h_vals, in.vals(), params.nnz, stream); @@ -102,7 +102,6 @@ TEST_P(COORemoveZeros, Result) ASSERT_TRUE(raft::devArrMatch(out_ref.cols(), out.cols(), 2, raft::Compare())); ASSERT_TRUE(raft::devArrMatch(out_ref.vals(), out.vals(), 2, raft::Compare())); - 
RAFT_CUDA_TRY(cudaStreamDestroy(stream)); free(out_vals_ref_h); delete[] in_h_rows; diff --git a/cpp/test/sparse/sort.cu b/cpp/test/sparse/sort.cu index 462ba5fb80..ecea5344e7 100644 --- a/cpp/test/sparse/sort.cu +++ b/cpp/test/sparse/sort.cu @@ -50,15 +50,15 @@ TEST_P(COOSort, Result) { params = ::testing::TestWithParam>::GetParam(); raft::random::RngState r(params.seed); - cudaStream_t stream; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + raft::handle_t h; + auto stream = h.get_stream(); rmm::device_uvector in_rows(params.nnz, stream); rmm::device_uvector in_cols(params.nnz, stream); rmm::device_uvector verify(params.nnz, stream); rmm::device_uvector in_vals(params.nnz, stream); - uniform(r, in_vals.data(), params.nnz, float(-1.0), float(1.0), stream); + uniform(h, r, in_vals.data(), params.nnz, float(-1.0), float(1.0)); int* in_rows_h = (int*)malloc(params.nnz * sizeof(int)); int* in_cols_h = (int*)malloc(params.nnz * sizeof(int)); @@ -84,8 +84,6 @@ TEST_P(COOSort, Result) delete[] in_rows_h; delete[] in_cols_h; delete[] verify_h; - - RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } INSTANTIATE_TEST_CASE_P(SparseSortTest, COOSort, ::testing::ValuesIn(inputsf)); diff --git a/cpp/test/spatial/fused_l2_knn.cu b/cpp/test/spatial/fused_l2_knn.cu index 75d9fcc622..bb0b3a63d7 100644 --- a/cpp/test/spatial/fused_l2_knn.cu +++ b/cpp/test/spatial/fused_l2_knn.cu @@ -165,8 +165,8 @@ class FusedL2KNNTest : public ::testing::TestWithParam { unsigned long long int seed = 1234ULL; raft::random::RngState r(seed); - uniform(r, database.data(), num_db_vecs * dim, T(-1.0), T(1.0), stream_); - uniform(r, search_queries.data(), num_queries * dim, T(-1.0), T(1.0), stream_); + uniform(handle_, r, database.data(), num_db_vecs * dim, T(-1.0), T(1.0)); + uniform(handle_, r, search_queries.data(), num_queries * dim, T(-1.0), T(1.0)); } void launchFaissBfknn() diff --git a/cpp/test/spatial/selection.cu b/cpp/test/spatial/selection.cu index a132042ca2..862fad56b4 100644 --- 
a/cpp/test/spatial/selection.cu +++ b/cpp/test/spatial/selection.cu @@ -321,10 +321,11 @@ struct with_ref { auto algo = std::get<1>(ps); std::vector dists(spec.input_len * spec.n_inputs); - auto s = rmm::cuda_stream_default; + raft::handle_t handle; + auto s = handle.get_stream(); rmm::device_uvector dists_d(spec.input_len * spec.n_inputs, s); raft::random::RngState r(42); - normal(r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0), s); + normal(handle, r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0)); update_host(dists.data(), dists_d.data(), dists_d.size(), s); s.synchronize(); diff --git a/cpp/test/stats/cov.cu b/cpp/test/stats/cov.cu index 52e2c965c7..d9cc3ec8be 100644 --- a/cpp/test/stats/cov.cu +++ b/cpp/test/stats/cov.cu @@ -66,7 +66,7 @@ class CovTest : public ::testing::TestWithParam> { mean_act.resize(cols, stream); cov_act.resize(cols * cols, stream); - normal(r, data.data(), len, params.mean, var, stream); + normal(handle, r, data.data(), len, params.mean, var); raft::stats::mean( mean_act.data(), data.data(), cols, rows, params.sample, params.rowMajor, stream); cov(handle, diff --git a/cpp/test/stats/dispersion.cu b/cpp/test/stats/dispersion.cu index 8ab5b2ade8..b8fd9dfe80 100644 --- a/cpp/test/stats/dispersion.cu +++ b/cpp/test/stats/dispersion.cu @@ -51,13 +51,13 @@ class DispersionTest : public ::testing::TestWithParam> { params = ::testing::TestWithParam>::GetParam(); raft::random::RngState r(params.seed); int len = params.clusters * params.dim; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + stream = handle.get_stream(); rmm::device_uvector data(len, stream); rmm::device_uvector counts(params.clusters, stream); exp_mean.resize(params.dim, stream); act_mean.resize(params.dim, stream); - uniform(r, data.data(), len, (T)-1.0, (T)1.0, stream); - uniformInt(r, counts.data(), params.clusters, 1, 100, stream); + uniform(handle, r, data.data(), len, (T)-1.0, (T)1.0); + uniformInt(handle, r, counts.data(), params.clusters, 1, 100); 
std::vector h_counts(params.clusters, 0); raft::update_host(&(h_counts[0]), counts.data(), params.clusters, stream); npoints = 0; @@ -89,10 +89,9 @@ class DispersionTest : public ::testing::TestWithParam> { raft::interruptible::synchronize(stream); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - protected: DispersionInputs params; + raft::handle_t handle; rmm::device_uvector exp_mean, act_mean; cudaStream_t stream = 0; int npoints; diff --git a/cpp/test/stats/histogram.cu b/cpp/test/stats/histogram.cu index e687283b62..caf87b2581 100644 --- a/cpp/test/stats/histogram.cu +++ b/cpp/test/stats/histogram.cu @@ -68,13 +68,14 @@ class HistTest : public ::testing::TestWithParam { { params = ::testing::TestWithParam::GetParam(); raft::random::RngState r(params.seed); - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + raft::handle_t h; + stream = h.get_stream(); int len = params.nrows * params.ncols; in.resize(len, stream); if (params.isNormal) { - normalInt(r, in.data(), len, params.start, params.end, stream); + normalInt(h, r, in.data(), len, params.start, params.end); } else { - uniformInt(r, in.data(), len, params.start, params.end, stream); + uniformInt(h, r, in.data(), len, params.start, params.end); } bins.resize(params.nbins * params.ncols, stream); ref_bins.resize(params.nbins * params.ncols, stream); diff --git a/cpp/test/stats/mean.cu b/cpp/test/stats/mean.cu index 017b062872..b7f24d5642 100644 --- a/cpp/test/stats/mean.cu +++ b/cpp/test/stats/mean.cu @@ -58,7 +58,7 @@ class MeanTest : public ::testing::TestWithParam> { { raft::random::RngState r(params.seed); int len = rows * cols; - normal(r, data.data(), len, params.mean, (T)1.0, stream); + normal(handle, r, data.data(), len, params.mean, (T)1.0); meanSGtest(data.data(), stream); } diff --git a/cpp/test/stats/mean_center.cu b/cpp/test/stats/mean_center.cu index f72c47af75..3d92a52fb4 100644 --- a/cpp/test/stats/mean_center.cu +++ b/cpp/test/stats/mean_center.cu @@ -59,7 +59,7 @@ class 
MeanCenterTest : public ::testing::TestWithParam> { void SetUp() override { random::RngState r(params.seed); - normal(r, data.data(), params.cols * params.rows, params.mean, params.stddev, stream); + normal(handle, r, data.data(), params.cols * params.rows, params.mean, params.stddev); meanvar(mean_act.data(), vars_act.data(), data.data(), diff --git a/cpp/test/stats/minmax.cu b/cpp/test/stats/minmax.cu index a52a60a7a8..532932b6ba 100644 --- a/cpp/test/stats/minmax.cu +++ b/cpp/test/stats/minmax.cu @@ -90,23 +90,23 @@ __global__ void nanKernel(T* data, const bool* mask, int len, T nan) template class MinMaxTest : public ::testing::TestWithParam> { protected: - MinMaxTest() : minmax_act(0, stream), minmax_ref(0, stream) {} + MinMaxTest() : minmax_act(0, handle.get_stream()), minmax_ref(0, handle.get_stream()) {} void SetUp() override { - params = ::testing::TestWithParam>::GetParam(); + auto stream = handle.get_stream(); + params = ::testing::TestWithParam>::GetParam(); raft::random::RngState r(params.seed); int len = params.rows * params.cols; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); rmm::device_uvector data(len, stream); rmm::device_uvector mask(len, stream); minmax_act.resize(2 * params.cols, stream); minmax_ref.resize(2 * params.cols, stream); - normal(r, data.data(), len, (T)0.0, (T)1.0, stream); + normal(handle, r, data.data(), len, (T)0.0, (T)1.0); T nan_prob = 0.01; - bernoulli(r, mask.data(), len, nan_prob, stream); + bernoulli(handle, r, mask.data(), len, nan_prob); const int TPB = 256; nanKernel<<>>( data.data(), mask.data(), len, std::numeric_limits::quiet_NaN()); @@ -130,10 +130,10 @@ class MinMaxTest : public ::testing::TestWithParam> { } protected: + raft::handle_t handle; MinMaxInputs params; rmm::device_uvector minmax_act; rmm::device_uvector minmax_ref; - cudaStream_t stream = 0; }; const std::vector> inputsf = {{0.00001f, 1024, 32, 1234ULL}, diff --git a/cpp/test/stats/stddev.cu b/cpp/test/stats/stddev.cu index a49dd8a165..0521209e98 100644 
--- a/cpp/test/stats/stddev.cu +++ b/cpp/test/stats/stddev.cu @@ -64,7 +64,7 @@ class StdDevTest : public ::testing::TestWithParam> { mean_act.resize(cols, stream); stddev_act.resize(cols, stream); vars_act.resize(cols, stream); - normal(r, data.data(), len, params.mean, params.stddev, stream); + normal(handle, r, data.data(), len, params.mean, params.stddev); stdVarSGtest(data.data(), stream); handle.sync_stream(stream); } diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu index 373dc99ba7..ca0078cebb 100644 --- a/cpp/test/stats/weighted_mean.cu +++ b/cpp/test/stats/weighted_mean.cu @@ -70,8 +70,7 @@ class RowWeightedMeanTest : public ::testing::TestWithParam>::GetParam(); raft::random::RngState r(params.seed); int rows = params.M, cols = params.N, len = rows * cols; - cudaStream_t stream = 0; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + auto stream = handle.get_stream(); // device-side data din.resize(len); dweights.resize(cols); @@ -79,8 +78,8 @@ class RowWeightedMeanTest : public ::testing::TestWithParam hin = din; @@ -96,12 +95,12 @@ class RowWeightedMeanTest : public ::testing::TestWithParam params; thrust::host_vector hin, hweights; thrust::device_vector din, dweights, dexp, dact; @@ -136,8 +135,7 @@ class ColWeightedMeanTest : public ::testing::TestWithParam hin = din; @@ -162,12 +160,12 @@ class ColWeightedMeanTest : public ::testing::TestWithParam params; thrust::host_vector hin, hweights; thrust::device_vector din, dweights, dexp, dact; @@ -180,11 +178,10 @@ class WeightedMeanTest : public ::testing::TestWithParam> { params = ::testing::TestWithParam>::GetParam(); raft::random::RngState r(params.seed); + auto stream = handle.get_stream(); int rows = params.M, cols = params.N, len = rows * cols; - auto weight_size = params.along_rows ? cols : rows; - auto mean_size = params.along_rows ? rows : cols; - cudaStream_t stream = 0; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + auto weight_size = params.along_rows ? 
cols : rows; + auto mean_size = params.along_rows ? rows : cols; // device-side data din.resize(len); dweights.resize(weight_size); @@ -192,8 +189,8 @@ class WeightedMeanTest : public ::testing::TestWithParam> dact.resize(mean_size); // create random matrix and weights - uniform(r, din.data().get(), len, T(-1.0), T(1.0), stream); - uniform(r, dweights.data().get(), weight_size, T(-1.0), T(1.0), stream); + uniform(handle, r, din.data().get(), len, T(-1.0), T(1.0)); + uniform(handle, r, dweights.data().get(), weight_size, T(-1.0), T(1.0)); // host-side data thrust::host_vector hin = din; @@ -219,12 +216,12 @@ class WeightedMeanTest : public ::testing::TestWithParam> // adjust tolerance to account for round-off accumulation params.tolerance *= params.N; - RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } void TearDown() override {} protected: + raft::handle_t handle; WeightedMeanInputs params; thrust::host_vector hin, hweights; thrust::device_vector din, dweights, dexp, dact; From b327692d3b3ff6d3993e75002ed01c3f8a454431 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Tue, 3 May 2022 19:15:08 -0400 Subject: [PATCH 04/10] Final updates --- cpp/test/random/permute.cu | 28 +++++++++++++--------------- cpp/test/stats/histogram.cu | 20 ++++++++++---------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/cpp/test/random/permute.cu b/cpp/test/random/permute.cu index 84dc5970c4..6b0ee0457e 100644 --- a/cpp/test/random/permute.cu +++ b/cpp/test/random/permute.cu @@ -41,21 +41,21 @@ template template class PermTest : public ::testing::TestWithParam> { protected: - PermTest() : in(0, stream), out(0, stream), outPerms(0, stream) {} + PermTest() + : in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream()) + { + } void SetUp() override { - raft::handle_t h; - stream = h.get_stream(); - params = ::testing::TestWithParam>::GetParam(); + auto stream = handle.get_stream(); + params = ::testing::TestWithParam>::GetParam(); // forcefully set needPerms, since we need it for unit-testing! if (params.needShuffle) { params.needPerms = true; } raft::random::RngState r(params.seed); - int N = params.N; - int D = params.D; - int len = N * D; - cudaStream_t stream = 0; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + int N = params.N; + int D = params.D; + int len = N * D; if (params.needPerms) { outPerms.resize(N, stream); outPerms_ptr = outPerms.data(); @@ -65,22 +65,20 @@ class PermTest : public ::testing::TestWithParam> { out.resize(len, stream); in_ptr = in.data(); out_ptr = out.data(); - uniform(h, r, in_ptr, len, T(-1.0), T(1.0)); + uniform(handle, r, in_ptr, len, T(-1.0), T(1.0)); } permute(outPerms_ptr, out_ptr, in_ptr, D, N, params.rowMajor, stream); - h.sync_stream(); + handle.sync_stream(); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - protected: + raft::handle_t handle; PermInputs params; rmm::device_uvector in, out; T* in_ptr = nullptr; T* out_ptr = nullptr; rmm::device_uvector outPerms; - int* outPerms_ptr = nullptr; - cudaStream_t stream = 0; + int* 
outPerms_ptr = nullptr; }; template diff --git a/cpp/test/stats/histogram.cu b/cpp/test/stats/histogram.cu index caf87b2581..f09c01c84a 100644 --- a/cpp/test/stats/histogram.cu +++ b/cpp/test/stats/histogram.cu @@ -62,20 +62,22 @@ struct HistInputs { class HistTest : public ::testing::TestWithParam { protected: - HistTest() : in(0, stream), bins(0, stream), ref_bins(0, stream) {} + HistTest() + : in(0, handle.get_stream()), bins(0, handle.get_stream()), ref_bins(0, handle.get_stream()) + { + } void SetUp() override { params = ::testing::TestWithParam::GetParam(); raft::random::RngState r(params.seed); - raft::handle_t h; - stream = h.get_stream(); - int len = params.nrows * params.ncols; + auto stream = handle.get_stream(); + int len = params.nrows * params.ncols; in.resize(len, stream); if (params.isNormal) { - normalInt(h, r, in.data(), len, params.start, params.end); + normalInt(handle, r, in.data(), len, params.start, params.end); } else { - uniformInt(h, r, in.data(), len, params.start, params.end); + uniformInt(handle, r, in.data(), len, params.start, params.end); } bins.resize(params.nbins * params.ncols, stream); ref_bins.resize(params.nbins * params.ncols, stream); @@ -84,13 +86,11 @@ class HistTest : public ::testing::TestWithParam { naiveHist(ref_bins.data(), params.nbins, in.data(), params.nrows, params.ncols, stream); histogram( params.type, bins.data(), params.nbins, in.data(), params.nrows, params.ncols, stream); - raft::interruptible::synchronize(stream); + handle.sync_stream(); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - protected: - cudaStream_t stream = 0; + raft::handle_t handle; HistInputs params; rmm::device_uvector in, bins, ref_bins; }; From b576104357be0524b58671c1d5663d9f227f7970 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 4 May 2022 09:55:16 -0400 Subject: [PATCH 05/10] Updating rng for raft bench --- cpp/bench/common/benchmark.hpp | 4 ++-- cpp/bench/random/permute.cu | 3 ++- cpp/bench/random/rng.cu | 18 +++++++++--------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/cpp/bench/common/benchmark.hpp b/cpp/bench/common/benchmark.hpp index 93814ead44..85c33c6e72 100644 --- a/cpp/bench/common/benchmark.hpp +++ b/cpp/bench/common/benchmark.hpp @@ -102,13 +102,13 @@ struct cuda_event_timer { /** Main fixture to be inherited and used by all other c++ benchmarks */ class fixture { private: - rmm::cuda_stream stream_owner_{}; rmm::device_buffer scratch_buf_; public: + raft::handle_t handle; rmm::cuda_stream_view stream; - fixture() : stream{stream_owner_.view()} + fixture() : stream{handle.get_stream()} { int l2_cache_size = 0; int device_id = 0; diff --git a/cpp/bench/random/permute.cu b/cpp/bench/random/permute.cu index 14622bb0ef..a72eca3f87 100644 --- a/cpp/bench/random/permute.cu +++ b/cpp/bench/random/permute.cu @@ -37,7 +37,7 @@ struct permute : public fixture { in(p.rows * p.cols, stream) { raft::random::RngState r(123456ULL); - uniform(r, in.data(), p.rows, T(-1.0), T(1.0), stream); + uniform(handle, r, in.data(), p.rows, T(-1.0), T(1.0)); } void run_benchmark(::benchmark::State& state) override @@ -50,6 +50,7 @@ struct permute : public fixture { } private: + // use base fixture::handle (a second handle here would shadow it and put RNG work on a different stream) permute_inputs params; rmm::device_uvector out, in; rmm::device_uvector perms; diff --git a/cpp/bench/random/rng.cu b/cpp/bench/random/rng.cu index 88d724aef5..dab5b119d4 100644 --- a/cpp/bench/random/rng.cu +++ b/cpp/bench/random/rng.cu @@ -51,23 +51,23 @@ struct rng : public fixture { raft::random::RngState r(123456ULL, params.gtype); loop_on_state(state, [this, &r]() { switch (params.type) { - case RNG_Normal: normal(handle, r, ptr.data(), params.len, params.start, 
params.end); break; case RNG_LogNormal: - lognormal(r, ptr.data(), params.len, params.start, params.end, stream); + lognormal(handle, r, ptr.data(), params.len, params.start, params.end); break; case RNG_Uniform: - uniform(r, ptr.data(), params.len, params.start, params.end, stream); + uniform(handle, r, ptr.data(), params.len, params.start, params.end); break; - case RNG_Gumbel: gumbel(r, ptr.data(), params.len, params.start, params.end, stream); break; + case RNG_Gumbel: gumbel(handle, r, ptr.data(), params.len, params.start, params.end); break; case RNG_Logistic: - logistic(r, ptr.data(), params.len, params.start, params.end, stream); + logistic(handle, r, ptr.data(), params.len, params.start, params.end); break; - case RNG_Exp: exponential(r, ptr.data(), params.len, params.start, stream); break; - case RNG_Rayleigh: rayleigh(r, ptr.data(), params.len, params.start, stream); break; + case RNG_Exp: exponential(handle, r, ptr.data(), params.len, params.start); break; + case RNG_Rayleigh: rayleigh(handle, r, ptr.data(), params.len, params.start); break; case RNG_Laplace: - laplace(r, ptr.data(), params.len, params.start, params.end, stream); + laplace(handle, r, ptr.data(), params.len, params.start, params.end); break; - case RNG_Fill: fill(r, ptr.data(), params.len, params.start, stream); break; + case RNG_Fill: fill(handle, r, ptr.data(), params.len, params.start); break; }; }); } From ad1586dd7f87912605d9a3b4ab5bb6d5c86d16ba Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 4 May 2022 09:57:05 -0400 Subject: [PATCH 06/10] Removing empty teardown methods --- cpp/test/linalg/reduce_cols_by_key.cu | 2 -- cpp/test/stats/weighted_mean.cu | 6 ------ 2 files changed, 8 deletions(-) diff --git a/cpp/test/linalg/reduce_cols_by_key.cu b/cpp/test/linalg/reduce_cols_by_key.cu index 067e3a8b0e..6682f54ace 100644 --- a/cpp/test/linalg/reduce_cols_by_key.cu +++ b/cpp/test/linalg/reduce_cols_by_key.cu @@ -88,8 +88,6 @@ class ReduceColsTest : public ::testing::TestWithParam> { raft::interruptible::synchronize(stream); } - void TearDown() override {} - protected: cudaStream_t stream = 0; ReduceColsInputs params; diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu index ca0078cebb..9f3e6a79f6 100644 --- a/cpp/test/stats/weighted_mean.cu +++ b/cpp/test/stats/weighted_mean.cu @@ -97,8 +97,6 @@ class RowWeightedMeanTest : public ::testing::TestWithParam params; @@ -162,8 +160,6 @@ class ColWeightedMeanTest : public ::testing::TestWithParam params; @@ -218,8 +214,6 @@ class WeightedMeanTest : public ::testing::TestWithParam> params.tolerance *= params.N; } - void TearDown() override {} - protected: raft::handle_t handle; WeightedMeanInputs params; From a5b38e2d422ba652555dd9f72b3305a2b07fdde4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 4 May 2022 10:42:52 -0400 Subject: [PATCH 07/10] Including handle --- cpp/bench/common/benchmark.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/bench/common/benchmark.hpp b/cpp/bench/common/benchmark.hpp index 85c33c6e72..de34cf4f57 100644 --- a/cpp/bench/common/benchmark.hpp +++ b/cpp/bench/common/benchmark.hpp @@ -18,6 +18,7 @@ #include +#include #include #include From 11f8925ef15305c9aeb270e759ebc01d37253536 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 4 May 2022 14:54:15 -0400 Subject: [PATCH 08/10] Fixing remaining build issues --- cpp/bench/spatial/fused_l2_nn.cu | 4 ++-- cpp/bench/spatial/selection.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/bench/spatial/fused_l2_nn.cu b/cpp/bench/spatial/fused_l2_nn.cu index 3062708f50..dc3b507fbf 100644 --- a/cpp/bench/spatial/fused_l2_nn.cu +++ b/cpp/bench/spatial/fused_l2_nn.cu @@ -46,8 +46,8 @@ struct fused_l2_nn : public fixture { raft::handle_t handle{stream}; raft::random::RngState r(123456ULL); - uniform(r, x.data(), p.m * p.k, T(-1.0), T(1.0), stream); - uniform(r, y.data(), p.n * p.k, T(-1.0), T(1.0), stream); + uniform(handle, r, x.data(), p.m * p.k, T(-1.0), T(1.0)); + uniform(handle, r, y.data(), p.n * p.k, T(-1.0), T(1.0)); raft::linalg::rowNorm(xn.data(), x.data(), p.k, p.m, raft::linalg::L2Norm, true, stream); raft::linalg::rowNorm(yn.data(), y.data(), p.k, p.n, raft::linalg::L2Norm, true, stream); raft::distance::initialize, int>( diff --git a/cpp/bench/spatial/selection.cu b/cpp/bench/spatial/selection.cu index 2f2c995dd5..c3a2bc6d3d 100644 --- a/cpp/bench/spatial/selection.cu +++ b/cpp/bench/spatial/selection.cu @@ -47,7 +47,7 @@ struct selection : public fixture { { raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream); raft::random::RngState state{42}; - raft::random::uniform(state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0), stream); + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); } void run_benchmark(::benchmark::State& state) override From 61eb416b2c28751739582eb510fc4b9c91112ebb Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 4 May 2022 16:12:55 -0400 Subject: [PATCH 09/10] Updating rng docs to fix ci --- cpp/include/raft/random/rng.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index faa18cbcd9..1ded65538e 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -76,7 +76,7 @@ void uniformInt(const raft::handle_t& handle, * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management * @param[in] rng_state random number generator state - * @param[out[ ptr the output array + * @param[out] ptr the output array * @param[in] len the number of elements in the output * @param[in] mu mean of the distribution * @param[in] sigma std-dev of the distribution @@ -99,7 +99,7 @@ void normal(const raft::handle_t& handle, * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management * @param[in] rng_state random number generator state - * @param[out[ ptr the output array + * @param[out] ptr the output array * @param[in] len the number of elements in the output * @param[in] mu mean of the distribution * @param[in] sigma std-dev of the distribution From fe5f0ab9a4d4b52d107d5f481771f13633e5e169 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 5 May 2022 14:43:32 -0400 Subject: [PATCH 10/10] Adding old functions back just for compatibility. Will remove them. 
--- cpp/include/raft/random/rng.cuh | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 1ded65538e..85d9abe263 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -403,6 +403,23 @@ void affine_transform_params(RngState const& rng_state, IdxT n, IdxT& a, IdxT& b // without the macro, clang-format seems to go insane #define DEPR [[deprecated("Use 'RngState' with the new flat functions instead")]] +using detail::bernoulli; +using detail::exponential; +using detail::fill; +using detail::gumbel; +using detail::laplace; +using detail::logistic; +using detail::lognormal; +using detail::normal; +using detail::normalInt; +using detail::normalTable; +using detail::rayleigh; +using detail::scaled_bernoulli; +using detail::uniform; +using detail::uniformInt; + +using detail::sampleWithoutReplacement; + class DEPR Rng : public detail::RngImpl { public: /**