diff --git a/cpp/include/raft/random/detail/make_blobs.cuh b/cpp/include/raft/random/detail/make_blobs.cuh index 10ded9c93e..ece49a0811 100644 --- a/cpp/include/raft/random/detail/make_blobs.cuh +++ b/cpp/include/raft/random/detail/make_blobs.cuh @@ -107,7 +107,9 @@ __global__ void generate_data_kernel(DataT* out, IdxT len = n_rows * n_cols; for (IdxT idx = tid; idx < len; idx += stride) { DataT val1, val2; - gen.next(val1); + do { + gen.next(val1); + } while (val1 == DataT(0.0)); gen.next(val2); DataT mu1, sigma1, mu2, sigma2; get_mu_sigma(mu1, diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 2406456404..1b245ca45f 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -221,8 +221,13 @@ DI void custom_next( GenType& gen, OutType* val, NormalDistParams params, LenType idx = 0, LenType stride = 0) { OutType res1, res2; - gen.next(res1); + + do { + gen.next(res1); + } while (res1 == OutType(0.0)); + gen.next(res2); + box_muller_transform(res1, res2, params.sigma, params.mu); *val = res1; *(val + 1) = res2; @@ -236,7 +241,11 @@ DI void custom_next(GenType& gen, LenType stride = 0) { IntType res1_int, res2_int; - gen.next(res1_int); + + do { + gen.next(res1_int); + } while (res1_int == 0); + gen.next(res2_int); double res1 = static_cast(res1_int); double res2 = static_cast(res2_int); @@ -255,7 +264,11 @@ DI void custom_next(GenType& gen, LenType stride) { OutType res1, res2; - gen.next(res1); + + do { + gen.next(res1); + } while (res1 == OutType(0.0)); + gen.next(res2); LenType col1 = idx % params.n_cols; LenType col2 = (idx + stride) % params.n_cols; @@ -274,7 +287,7 @@ DI void custom_next( { Type res = 0; gen.next(res); - *val = res > params.prob; + *val = res < params.prob; } template @@ -286,7 +299,7 @@ DI void custom_next(GenType& gen, { OutType res = 0; gen.next(res); - *val = res > params.prob ? 
-params.scale : params.scale; + *val = res < params.prob ? -params.scale : params.scale; } template @@ -294,7 +307,11 @@ DI void custom_next( GenType& gen, OutType* val, GumbelDistParams params, LenType idx = 0, LenType stride = 0) { OutType res = 0; - gen.next(res); + + do { + gen.next(res); + } while (res == OutType(0.0)); + *val = params.mu - params.beta * raft::myLog(-raft::myLog(res)); } @@ -306,7 +323,10 @@ DI void custom_next(GenType& gen, LenType stride = 0) { OutType res1 = 0, res2 = 0; - gen.next(res1); + do { + gen.next(res1); + } while (res1 == OutType(0.0)); + gen.next(res2); box_muller_transform(res1, res2, params.sigma, params.mu); *val = raft::myExp(res1); @@ -321,7 +341,11 @@ DI void custom_next(GenType& gen, LenType stride = 0) { OutType res; - gen.next(res); + + do { + gen.next(res); + } while (res == OutType(0.0)); + constexpr OutType one = (OutType)1.0; *val = params.mu - params.scale * raft::myLog(one / res - one); } @@ -348,6 +372,7 @@ DI void custom_next(GenType& gen, { OutType res; gen.next(res); + constexpr OutType one = (OutType)1.0; constexpr OutType two = (OutType)2.0; *val = raft::mySqrt(-two * raft::myLog(one - res)) * params.sigma; @@ -361,10 +386,17 @@ DI void custom_next(GenType& gen, LenType stride = 0) { OutType res, out; - gen.next(res); + + do { + gen.next(res); + } while (res == OutType(0.0)); + constexpr OutType one = (OutType)1.0; constexpr OutType two = (OutType)2.0; constexpr OutType oneHalf = (OutType)0.5; + + // The <= comparison here means, number of samples going in `if` branch are more by 1 than `else` + // branch. However it does not matter as for 0.5 both branches evaluate to same result. 
if (res <= oneHalf) { out = params.mu + params.scale * raft::myLog(two * res); } else { @@ -451,8 +483,33 @@ struct PhiloxGenerator { return ret; } - DI void next(float& ret) { ret = curand_uniform(&(this->philox_state)); } - DI void next(double& ret) { ret = curand_uniform_double(&(this->philox_state)); } + DI float next_float() + { + float ret; + uint32_t val = next_u32() >> 8; + ret = static_cast(val) / float(uint32_t(1) << 24); + return ret; + } + + DI double next_double() + { + double ret; + uint64_t val = next_u64() >> 11; + ret = static_cast(val) / double(uint64_t(1) << 53); + return ret; + } + + DI void next(float& ret) + { + // ret = curand_uniform(&(this->philox_state)); + ret = next_float(); + } + + DI void next(double& ret) + { + // ret = curand_uniform_double(&(this->philox_state)); + ret = next_double(); + } DI void next(uint32_t& ret) { ret = next_u32(); } DI void next(uint64_t& ret) { ret = next_u64(); } diff --git a/cpp/test/linalg/gemm_layout.cu b/cpp/test/linalg/gemm_layout.cu index 422ba26f46..baf8cc00f4 100644 --- a/cpp/test/linalg/gemm_layout.cu +++ b/cpp/test/linalg/gemm_layout.cu @@ -128,7 +128,7 @@ const std::vector> inputsf = { {50, 10, 60, false, true, true, 73012ULL}, {90, 90, 30, false, true, false, 538147ULL}, {30, 100, 10, false, false, true, 412352ULL}, - {40, 80, 100, false, false, false, 297941ULL}}; + {40, 80, 100, false, false, false, 2979410ULL}}; const std::vector> inputsd = { {10, 70, 40, true, true, true, 535648ULL}, diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index 872ed25000..28e3e461c7 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -14,6 +14,8 @@ * limitations under the License. 
*/ +#include + #include "../test_utils.h" #include #include @@ -58,28 +60,34 @@ __global__ void meanKernel(T* out, const T* data, int len) template struct RngInputs { - T tolerance; int len; - // start, end: for uniform - // mean, sigma: for normal/lognormal - // mean, beta: for gumbel - // mean, scale: for logistic and laplace - // lambda: for exponential - // sigma: for rayleigh + // Meaning of 'start' and 'end' parameter for various distributions + // + // Uniform Normal/Log-Normal Gumbel Logistic Laplace Exponential Rayleigh + // start start mean mean mean mean lambda sigma + // end end sigma beta scale scale Unused Unused T start, end; RandomType type; GeneratorType gtype; - unsigned long long int seed; + uint64_t seed; }; -template -::std::ostream& operator<<(::std::ostream& os, const RngInputs& dims) -{ - return os; -} - -#include -#include +// In this test we generate pseudo-random values that follow various probability distributions such +// as Normal, Laplace etc. To check the correctness of generated random variates we compute two +// measures, mean and variance from the generated data. The computed values are matched against +// their theoretically expected values for the corresponding distribution. The computed mean and +// variance are statistical variables themselves and follow a Normal distribution. Which means, +// there is 99+% chance that the computed values fall in the 3-sigma (standard deviation) interval +// [theoretical_value - 3*sigma, theoretical_value + 3*sigma]. The values are practically +// guaranteed to fall in the 4-sigma interval. Reference standard deviation of the computed +// mean/variance distribution is calculated here +// https://gist.github.com/vinaydes/cee04f50ff7e3365759603d39b7e079b Maximum standard deviation +// observed here is ~1.5e-2, thus we use this as sigma in our test. +// N O T E: Before adding any new test case below, make sure to calculate standard deviation for the +// test parameters using above notebook. 
+ +constexpr int NUM_SIGMA = 4; +constexpr double MAX_SIGMA = 1.5e-2; template class RngTest : public ::testing::TestWithParam> { @@ -97,9 +105,6 @@ class RngTest : public ::testing::TestWithParam> { protected: void SetUp() override { - // Tests are configured with their expected test-values sigma. For example, - // 4 x sigma indicates the test shouldn't fail 99.9% of the time. - num_sigma = 4; Rng r(params.seed, params.gtype); switch (params.type) { case RNG_Normal: r.normal(data.data(), params.len, params.start, params.end, stream); break; @@ -176,118 +181,61 @@ class RngTest : public ::testing::TestWithParam> { RngInputs params; rmm::device_uvector data, stats; T h_stats[2]; // mean, var - int num_sigma; }; -// The measured mean and standard deviation for each tested distribution are, -// of course, statistical variables. Thus setting an appropriate testing -// tolerance essentially requires one to set a probability of test failure. We -// choose to set this at 3-4 x sigma, i.e., a 99.7-99.9% confidence interval so that -// the test will indeed pass. In quick experiments (using the identical -// distributions given by NumPy/SciPy), the measured standard deviation is the -// variable with the greatest variance and so we determined the variance for -// each distribution and number of samples (32*1024 or 8*1024). Below -// are listed the standard deviation for these tests. 
- -// Distribution: StdDev 32*1024, StdDev 8*1024 -// Normal: 0.0055, 0.011 -// LogNormal: 0.05, 0.1 -// Uniform: 0.003, 0.005 -// Gumbel: 0.005, 0.01 -// Logistic: 0.005, 0.01 -// Exp: 0.008, 0.015 -// Rayleigh: 0.0125, 0.025 -// Laplace: 0.02, 0.04 - -// We generally want 4 x sigma >= 99.9% chance of success - typedef RngTest RngTestF; const std::vector> inputsf = { - {0.0055, 32 * 1024, 1.f, 1.f, RNG_Normal, GenPhilox, 1234ULL}, - {0.011, 8 * 1024, 1.f, 1.f, RNG_Normal, GenPhilox, 1234ULL}, - {0.05, 32 * 1024, 1.f, 1.f, RNG_LogNormal, GenPhilox, 1234ULL}, - {0.1, 8 * 1024, 1.f, 1.f, RNG_LogNormal, GenPhilox, 1234ULL}, - {0.003, 32 * 1024, -1.f, 1.f, RNG_Uniform, GenPhilox, 1234ULL}, - {0.005, 8 * 1024, -1.f, 1.f, RNG_Uniform, GenPhilox, 1234ULL}, - {0.005, 32 * 1024, 1.f, 1.f, RNG_Gumbel, GenPhilox, 1234ULL}, - {0.01, 8 * 1024, 1.f, 1.f, RNG_Gumbel, GenPhilox, 1234ULL}, - {0.005, 32 * 1024, 1.f, 1.f, RNG_Logistic, GenPhilox, 67632ULL}, - {0.01, 8 * 1024, 1.f, 1.f, RNG_Logistic, GenPhilox, 1234ULL}, - {0.008, 32 * 1024, 1.f, 1.f, RNG_Exp, GenPhilox, 1234ULL}, - {0.015, 8 * 1024, 1.f, 1.f, RNG_Exp, GenPhilox, 1234ULL}, - {0.0125, 32 * 1024, 1.f, 1.f, RNG_Rayleigh, GenPhilox, 1234ULL}, - {0.025, 8 * 1024, 1.f, 1.f, RNG_Rayleigh, GenPhilox, 1234ULL}, - {0.02, 32 * 1024, 1.f, 1.f, RNG_Laplace, GenPhilox, 1234ULL}, - {0.04, 8 * 1024, 1.f, 1.f, RNG_Laplace, GenPhilox, 1234ULL}, - - {0.0055, 32 * 1024, 1.f, 1.f, RNG_Normal, GenPC, 1234ULL}, - {0.011, 8 * 1024, 1.f, 1.f, RNG_Normal, GenPC, 1234ULL}, - {0.05, 32 * 1024, 1.f, 1.f, RNG_LogNormal, GenPC, 1234ULL}, - {0.1, 8 * 1024, 1.f, 1.f, RNG_LogNormal, GenPC, 1234ULL}, - {0.003, 32 * 1024, -1.f, 1.f, RNG_Uniform, GenPC, 1234ULL}, - {0.005, 8 * 1024, -1.f, 1.f, RNG_Uniform, GenPC, 1234ULL}, - {0.005, 32 * 1024, 1.f, 1.f, RNG_Gumbel, GenPC, 1234ULL}, - {0.01, 8 * 1024, 1.f, 1.f, RNG_Gumbel, GenPC, 1234ULL}, - {0.005, 32 * 1024, 1.f, 1.f, RNG_Logistic, GenPC, 1234ULL}, - {0.01, 8 * 1024, 1.f, 1.f, RNG_Logistic, GenPC, 
1234ULL}, - {0.008, 32 * 1024, 1.f, 1.f, RNG_Exp, GenPC, 1234ULL}, - {0.015, 8 * 1024, 1.f, 1.f, RNG_Exp, GenPC, 1234ULL}, - {0.0125, 32 * 1024, 1.f, 1.f, RNG_Rayleigh, GenPC, 1234ULL}, - {0.025, 8 * 1024, 1.f, 1.f, RNG_Rayleigh, GenPC, 1234ULL}, - {0.02, 32 * 1024, 1.f, 1.f, RNG_Laplace, GenPC, 1234ULL}, - {0.04, 8 * 1024, 1.f, 1.f, RNG_Laplace, GenPC, 1234ULL}}; + // Test with Philox + {1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPhilox, 1234ULL}, + {1024 * 1024, 1.2f, 0.1f, RNG_LogNormal, GenPhilox, 1234ULL}, + {1024 * 1024, 1.2f, 5.5f, RNG_Uniform, GenPhilox, 1234ULL}, + {1024 * 1024, 0.1f, 1.3f, RNG_Gumbel, GenPhilox, 1234ULL}, + {1024 * 1024, 1.6f, 0.0f, RNG_Exp, GenPhilox, 1234ULL}, + {1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPhilox, 1234ULL}, + {1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPhilox, 1234ULL}, + // Test with PCG + {1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPC, 1234ULL}, + {1024 * 1024, 1.2f, 0.1f, RNG_LogNormal, GenPC, 1234ULL}, + {1024 * 1024, 1.2f, 5.5f, RNG_Uniform, GenPC, 1234ULL}, + {1024 * 1024, 0.1f, 1.3f, RNG_Gumbel, GenPC, 1234ULL}, + {1024 * 1024, 1.6f, 0.0f, RNG_Exp, GenPC, 1234ULL}, + {1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPC, 1234ULL}, + {1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPC, 1234ULL}}; TEST_P(RngTestF, Result) { float meanvar[2]; getExpectedMeanVar(meanvar); - ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(num_sigma * params.tolerance))); - ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(num_sigma * params.tolerance))); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(NUM_SIGMA * MAX_SIGMA))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(NUM_SIGMA * MAX_SIGMA))); } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestF, ::testing::ValuesIn(inputsf)); typedef RngTest RngTestD; const std::vector> inputsd = { - {0.0055, 32 * 1024, 1.0, 1.0, RNG_Normal, GenPhilox, 1234ULL}, - {0.011, 8 * 1024, 1.0, 1.0, RNG_Normal, GenPhilox, 1234ULL}, - {0.05, 32 * 1024, 1.0, 1.0, RNG_LogNormal, GenPhilox, 1234ULL}, - 
{0.1, 8 * 1024, 1.0, 1.0, RNG_LogNormal, GenPhilox, 1234ULL}, - {0.003, 32 * 1024, -1.0, 1.0, RNG_Uniform, GenPhilox, 1234ULL}, - {0.005, 8 * 1024, -1.0, 1.0, RNG_Uniform, GenPhilox, 1234ULL}, - {0.005, 32 * 1024, 1.0, 1.0, RNG_Gumbel, GenPhilox, 1234ULL}, - {0.01, 8 * 1024, 1.0, 1.0, RNG_Gumbel, GenPhilox, 1234ULL}, - {0.005, 32 * 1024, 1.0, 1.0, RNG_Logistic, GenPhilox, 67632ULL}, - {0.01, 8 * 1024, 1.0, 1.0, RNG_Logistic, GenPhilox, 1234ULL}, - {0.008, 32 * 1024, 1.0, 1.0, RNG_Exp, GenPhilox, 1234ULL}, - {0.015, 8 * 1024, 1.0, 1.0, RNG_Exp, GenPhilox, 1234ULL}, - {0.0125, 32 * 1024, 1.0, 1.0, RNG_Rayleigh, GenPhilox, 1234ULL}, - {0.025, 8 * 1024, 1.0, 1.0, RNG_Rayleigh, GenPhilox, 1234ULL}, - {0.02, 32 * 1024, 1.0, 1.0, RNG_Laplace, GenPhilox, 1234ULL}, - {0.04, 8 * 1024, 1.0, 1.0, RNG_Laplace, GenPhilox, 1234ULL}, - - {0.0055, 32 * 1024, 1.0, 1.0, RNG_Normal, GenPC, 1234ULL}, - {0.011, 8 * 1024, 1.0, 1.0, RNG_Normal, GenPC, 1234ULL}, - {0.05, 32 * 1024, 1.0, 1.0, RNG_LogNormal, GenPC, 1234ULL}, - {0.1, 8 * 1024, 1.0, 1.0, RNG_LogNormal, GenPC, 1234ULL}, - {0.003, 32 * 1024, -1.0, 1.0, RNG_Uniform, GenPC, 1234ULL}, - {0.005, 8 * 1024, -1.0, 1.0, RNG_Uniform, GenPC, 1234ULL}, - {0.005, 32 * 1024, 1.0, 1.0, RNG_Gumbel, GenPC, 1234ULL}, - {0.01, 8 * 1024, 1.0, 1.0, RNG_Gumbel, GenPC, 1234ULL}, - {0.005, 32 * 1024, 1.0, 1.0, RNG_Logistic, GenPC, 1234ULL}, - {0.01, 8 * 1024, 1.0, 1.0, RNG_Logistic, GenPC, 1234ULL}, - {0.008, 32 * 1024, 1.0, 1.0, RNG_Exp, GenPC, 1234ULL}, - {0.015, 8 * 1024, 1.0, 1.0, RNG_Exp, GenPC, 1234ULL}, - {0.0125, 32 * 1024, 1.0, 1.0, RNG_Rayleigh, GenPC, 1234ULL}, - {0.025, 8 * 1024, 1.0, 1.0, RNG_Rayleigh, GenPC, 1234ULL}, - {0.02, 32 * 1024, 1.0, 1.0, RNG_Laplace, GenPC, 1234ULL}, - {0.04, 8 * 1024, 1.0, 1.0, RNG_Laplace, GenPC, 1234ULL}}; + // Test with Philox + {1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPhilox, 1234ULL}, + {1024 * 1024, 1.2f, 0.1f, RNG_LogNormal, GenPhilox, 1234ULL}, + {1024 * 1024, 1.2f, 5.5f, RNG_Uniform, GenPhilox, 
1234ULL}, + {1024 * 1024, 0.1f, 1.3f, RNG_Gumbel, GenPhilox, 1234ULL}, + {1024 * 1024, 1.6f, 0.0f, RNG_Exp, GenPhilox, 1234ULL}, + {1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPhilox, 1234ULL}, + {1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPhilox, 1234ULL}, + // Test with PCG + {1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPC, 1234ULL}, + {1024 * 1024, 1.2f, 0.1f, RNG_LogNormal, GenPC, 1234ULL}, + {1024 * 1024, 1.2f, 5.5f, RNG_Uniform, GenPC, 1234ULL}, + {1024 * 1024, 0.1f, 1.3f, RNG_Gumbel, GenPC, 1234ULL}, + {1024 * 1024, 1.6f, 0.0f, RNG_Exp, GenPC, 1234ULL}, + {1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPC, 1234ULL}, + {1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPC, 1234ULL}}; TEST_P(RngTestD, Result) { double meanvar[2]; getExpectedMeanVar(meanvar); - ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(num_sigma * params.tolerance))); - ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(num_sigma * params.tolerance))); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(NUM_SIGMA * MAX_SIGMA))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(NUM_SIGMA * MAX_SIGMA))); } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestD, ::testing::ValuesIn(inputsd)); @@ -326,12 +274,11 @@ std::ostream& operator<<(std::ostream& out, const std::vector& v) return out; } -// The following tests the 3 random number generators by checking that the -// measured mean error is close to the well-known analytical result -// (sigma/sqrt(n_samples)). To compute the mean error, we a number of -// experiments computing the mean, giving us a distribution of the mean -// itself. The mean error is simply the standard deviation of this -// distribution (the standard deviation of the mean). +// The following tests the two random number generators by checking that the measured mean error is +// close to the well-known analytical result (sigma/sqrt(n_samples)). To compute the mean error, we +// run a number of experiments computing the mean, giving us a distribution of the mean itself.
The +// mean error is simply the standard deviation of this distribution (the standard deviation of the +// mean). TEST(Rng, MeanError) { timeb time_struct; @@ -380,7 +327,8 @@ TEST(Rng, MeanError) auto diff_expected_vs_measured_mean_error = std::abs(d_std_of_mean - d_std / std::sqrt(num_samples)); - ASSERT_TRUE((diff_expected_vs_measured_mean_error / d_std_of_mean_analytical < 0.5)); + ASSERT_TRUE((diff_expected_vs_measured_mean_error / d_std_of_mean_analytical < 0.5)) + << "Failed with seed: " << seed << "\nrtype: " << rtype; } RAFT_CUDA_TRY(cudaStreamDestroy(stream)); diff --git a/cpp/test/spatial/epsilon_neighborhood.cu b/cpp/test/spatial/epsilon_neighborhood.cu index 30cd79188b..c005549b04 100644 --- a/cpp/test/spatial/epsilon_neighborhood.cu +++ b/cpp/test/spatial/epsilon_neighborhood.cu @@ -93,26 +93,10 @@ TEST_P(EpsNeighTestFI, Result) for (int i = 0; i < param.n_batches; ++i) { RAFT_CUDA_TRY(cudaMemsetAsync(adj.data(), 0, sizeof(bool) * param.n_row * batchSize, stream)); RAFT_CUDA_TRY(cudaMemsetAsync(vd.data(), 0, sizeof(int) * (batchSize + 1), stream)); - epsUnexpL2SqNeighborhood(adj. - - data(), - vd - - . - - data(), - data - - . - - data(), - data - - . - - data() - - + (i * batchSize * param.n_col), + epsUnexpL2SqNeighborhood(adj.data(), + vd.data(), + data.data(), + data.data() + (i * batchSize * param.n_col), param.n_row, batchSize, param.n_col,