Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] RNG test fixes and improvements #513

Merged
merged 23 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c9ed5a5
Updating the test description
vinaydes Feb 17, 2022
5db25ee
Correcting the seed type and updating parameter description
vinaydes Feb 17, 2022
873a2c6
Changing the test case parameters
vinaydes Feb 17, 2022
672e455
Fixing NaNs/Infs in the unit tests
vinaydes Feb 17, 2022
727dbc3
Formatting
vinaydes Feb 17, 2022
555339e
Moving the test description
vinaydes Feb 17, 2022
3676480
Correcting tolerances
vinaydes Feb 17, 2022
a677246
Consolidating tolerance parameters
vinaydes Feb 17, 2022
83dd813
Removing unused function
vinaydes Feb 17, 2022
dc14296
Moving include statement to correct place
vinaydes Feb 17, 2022
500864d
Replacing cuRAND uniform calls to change the range of generation from…
vinaydes Feb 24, 2022
422c37f
Added a temporary test
vinaydes Feb 24, 2022
8f53e4a
Removing check on res2 as it is not needed
vinaydes Feb 25, 2022
8c60042
Adding check for all Box-Muller calls
vinaydes Feb 25, 2022
98c2b2e
Fixing the log related checks
vinaydes Feb 25, 2022
758863f
Removing the debug test
vinaydes Feb 25, 2022
b3ebe3b
Formatting fixes
vinaydes Feb 25, 2022
be22797
Removing the possibility of passing zero to Box Muller transform
vinaydes Mar 2, 2022
1bbe79b
Changing the seed for a test case to avoid degenerate case
vinaydes Mar 2, 2022
ebb1a62
Improving format
vinaydes Mar 2, 2022
9e6853c
Correcting the Bernoulli dist generation
vinaydes Mar 8, 2022
7224a26
Adding comment explaining Laplace implementation
vinaydes Mar 8, 2022
e024451
Making integer to float/double conversion explicit
vinaydes Mar 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions cpp/include/raft/random/detail/rng_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ DI void custom_next(
GenType& gen, OutType* val, GumbelDistParams<OutType> params, LenType idx = 0, LenType stride = 0)
{
OutType res = 0;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(1.0));
vinaydes marked this conversation as resolved.
Show resolved Hide resolved

*val = params.mu - params.beta * raft::myLog(-raft::myLog(res));
}

Expand Down Expand Up @@ -334,7 +338,11 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(1.0));

constexpr OutType one = (OutType)1.0;
*val = -raft::myLog(one - res) / params.lambda;
}
Expand All @@ -347,7 +355,11 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(1.0));

constexpr OutType one = (OutType)1.0;
constexpr OutType two = (OutType)2.0;
*val = raft::mySqrt(-two * raft::myLog(one - res)) * params.sigma;
Expand All @@ -361,7 +373,11 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res, out;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(0.0) || res == OutType(1.0));

constexpr OutType one = (OutType)1.0;
constexpr OutType two = (OutType)2.0;
constexpr OutType oneHalf = (OutType)0.5;
Expand Down
172 changes: 60 additions & 112 deletions cpp/test/random/rng.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#include <sys/timeb.h>

#include "../test_utils.h"
#include <cub/cub.cuh>
#include <gtest/gtest.h>
Expand Down Expand Up @@ -58,28 +60,34 @@ __global__ void meanKernel(T* out, const T* data, int len)

template <typename T>
struct RngInputs {
T tolerance;
int len;
// start, end: for uniform
// mean, sigma: for normal/lognormal
// mean, beta: for gumbel
// mean, scale: for logistic and laplace
// lambda: for exponential
// sigma: for rayleigh
// Meaning of the 'start' and 'end' parameters for the various distributions:
//
//         Uniform   Normal/Log-Normal   Gumbel   Logistic   Laplace   Exponential   Rayleigh
// start   start     mean                mean     mean       mean      lambda        sigma
// end     end       sigma               beta     scale      scale     Unused        Unused
T start, end;
RandomType type;
GeneratorType gtype;
unsigned long long int seed;
uint64_t seed;
};

// Stream-insertion overload for RngInputs<T>.
// Writes nothing to the stream: `dims` is ignored and `os` is returned unchanged.
// NOTE(review): presumably present so the test framework can "print" parameter values
// without dumping raw bytes — confirm against the gtest value-printing rules.
template <typename T>
::std::ostream& operator<<(::std::ostream& os, const RngInputs<T>& dims)
{
return os;
}

#include <sys/timeb.h>
#include <time.h>
// In this test we generate pseudo-random values that follow various probability distributions,
// such as Normal, Laplace, etc. To check the correctness of the generated random variates, we
// compute two measures, the mean and the variance, from the generated data. The computed values
// are matched against their theoretically expected values for the corresponding distribution.
// The computed mean and variance are statistical variables themselves and follow a Normal
// distribution. This means there is a 99+% chance that the computed values fall in the 3-sigma
// (standard deviation) interval [theoretical_value - 3*sigma, theoretical_value + 3*sigma]. The
// values are practically guaranteed to fall in the 4-sigma interval. The reference standard
// deviation of the computed mean/variance distribution is calculated here:
// https://gist.github.com/vinaydes/cee04f50ff7e3365759603d39b7e079b
// The maximum standard deviation observed there is ~1.5e-2, thus we use this as sigma in our test.
// N O T E: Before adding any new test case below, make sure to calculate the standard deviation
// for the test parameters using the above notebook.

// Width, in standard deviations, of the acceptance interval around the expected value.
// The tests compare computed mean/variance with tolerance NUM_SIGMA * MAX_SIGMA.
constexpr int NUM_SIGMA = 4;
// Standard deviation assumed for the computed mean/variance across all test cases (1.5e-2).
constexpr double MAX_SIGMA = 1.5e-2;

template <typename T>
class RngTest : public ::testing::TestWithParam<RngInputs<T>> {
Expand All @@ -97,9 +105,6 @@ class RngTest : public ::testing::TestWithParam<RngInputs<T>> {
protected:
void SetUp() override
{
// Tests are configured with their expected test-values sigma. For example,
// 4 x sigma indicates the test shouldn't fail 99.9% of the time.
num_sigma = 4;
Rng r(params.seed, params.gtype);
switch (params.type) {
case RNG_Normal: r.normal(data.data(), params.len, params.start, params.end, stream); break;
Expand Down Expand Up @@ -176,118 +181,61 @@ class RngTest : public ::testing::TestWithParam<RngInputs<T>> {
RngInputs<T> params;
rmm::device_uvector<T> data, stats;
T h_stats[2]; // mean, var
int num_sigma;
};

// The measured mean and standard deviation for each tested distribution are,
// of course, statistical variables. Thus setting an appropriate testing
// tolerance essentially requires one to set a probability of test failure. We
// choose to set this at 3-4 x sigma, i.e., a 99.7-99.9% confidence interval so that
// the test will indeed pass. In quick experiments (using the identical
// distributions given by NumPy/SciPy), the measured standard deviation is the
// variable with the greatest variance and so we determined the variance for
// each distribution and number of samples (32*1024 or 8*1024). Below
// are listed the standard deviation for these tests.

// Distribution: StdDev 32*1024, StdDev 8*1024
// Normal: 0.0055, 0.011
// LogNormal: 0.05, 0.1
// Uniform: 0.003, 0.005
// Gumbel: 0.005, 0.01
// Logistic: 0.005, 0.01
// Exp: 0.008, 0.015
// Rayleigh: 0.0125, 0.025
// Laplace: 0.02, 0.04

// We generally want 4 x sigma >= 99.9% chance of success

typedef RngTest<float> RngTestF;
const std::vector<RngInputs<float>> inputsf = {
{0.0055, 32 * 1024, 1.f, 1.f, RNG_Normal, GenPhilox, 1234ULL},
{0.011, 8 * 1024, 1.f, 1.f, RNG_Normal, GenPhilox, 1234ULL},
{0.05, 32 * 1024, 1.f, 1.f, RNG_LogNormal, GenPhilox, 1234ULL},
{0.1, 8 * 1024, 1.f, 1.f, RNG_LogNormal, GenPhilox, 1234ULL},
{0.003, 32 * 1024, -1.f, 1.f, RNG_Uniform, GenPhilox, 1234ULL},
{0.005, 8 * 1024, -1.f, 1.f, RNG_Uniform, GenPhilox, 1234ULL},
{0.005, 32 * 1024, 1.f, 1.f, RNG_Gumbel, GenPhilox, 1234ULL},
{0.01, 8 * 1024, 1.f, 1.f, RNG_Gumbel, GenPhilox, 1234ULL},
{0.005, 32 * 1024, 1.f, 1.f, RNG_Logistic, GenPhilox, 67632ULL},
{0.01, 8 * 1024, 1.f, 1.f, RNG_Logistic, GenPhilox, 1234ULL},
{0.008, 32 * 1024, 1.f, 1.f, RNG_Exp, GenPhilox, 1234ULL},
{0.015, 8 * 1024, 1.f, 1.f, RNG_Exp, GenPhilox, 1234ULL},
{0.0125, 32 * 1024, 1.f, 1.f, RNG_Rayleigh, GenPhilox, 1234ULL},
{0.025, 8 * 1024, 1.f, 1.f, RNG_Rayleigh, GenPhilox, 1234ULL},
{0.02, 32 * 1024, 1.f, 1.f, RNG_Laplace, GenPhilox, 1234ULL},
{0.04, 8 * 1024, 1.f, 1.f, RNG_Laplace, GenPhilox, 1234ULL},

{0.0055, 32 * 1024, 1.f, 1.f, RNG_Normal, GenPC, 1234ULL},
{0.011, 8 * 1024, 1.f, 1.f, RNG_Normal, GenPC, 1234ULL},
{0.05, 32 * 1024, 1.f, 1.f, RNG_LogNormal, GenPC, 1234ULL},
{0.1, 8 * 1024, 1.f, 1.f, RNG_LogNormal, GenPC, 1234ULL},
{0.003, 32 * 1024, -1.f, 1.f, RNG_Uniform, GenPC, 1234ULL},
{0.005, 8 * 1024, -1.f, 1.f, RNG_Uniform, GenPC, 1234ULL},
{0.005, 32 * 1024, 1.f, 1.f, RNG_Gumbel, GenPC, 1234ULL},
{0.01, 8 * 1024, 1.f, 1.f, RNG_Gumbel, GenPC, 1234ULL},
{0.005, 32 * 1024, 1.f, 1.f, RNG_Logistic, GenPC, 1234ULL},
{0.01, 8 * 1024, 1.f, 1.f, RNG_Logistic, GenPC, 1234ULL},
{0.008, 32 * 1024, 1.f, 1.f, RNG_Exp, GenPC, 1234ULL},
{0.015, 8 * 1024, 1.f, 1.f, RNG_Exp, GenPC, 1234ULL},
{0.0125, 32 * 1024, 1.f, 1.f, RNG_Rayleigh, GenPC, 1234ULL},
{0.025, 8 * 1024, 1.f, 1.f, RNG_Rayleigh, GenPC, 1234ULL},
{0.02, 32 * 1024, 1.f, 1.f, RNG_Laplace, GenPC, 1234ULL},
{0.04, 8 * 1024, 1.f, 1.f, RNG_Laplace, GenPC, 1234ULL}};
// Test with Philox
{1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPhilox, 1234ULL},
{1024 * 1024, 1.2f, 0.1f, RNG_LogNormal, GenPhilox, 1234ULL},
{1024 * 1024, 1.2f, 5.5f, RNG_Uniform, GenPhilox, 1234ULL},
{1024 * 1024, 0.1f, 1.3f, RNG_Gumbel, GenPhilox, 1234ULL},
{1024 * 1024, 1.6f, 0.0f, RNG_Exp, GenPhilox, 1234ULL},
{1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPhilox, 1234ULL},
{1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPhilox, 1234ULL},
// Test with PCG
{1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPC, 1234ULL},
{1024 * 1024, 1.2f, 0.1f, RNG_LogNormal, GenPC, 1234ULL},
{1024 * 1024, 1.2f, 5.5f, RNG_Uniform, GenPC, 1234ULL},
{1024 * 1024, 0.1f, 1.3f, RNG_Gumbel, GenPC, 1234ULL},
{1024 * 1024, 1.6f, 0.0f, RNG_Exp, GenPC, 1234ULL},
{1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPC, 1234ULL},
{1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPC, 1234ULL}};

TEST_P(RngTestF, Result)
{
float meanvar[2];
getExpectedMeanVar(meanvar);
ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox<float>(num_sigma * params.tolerance)));
ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox<float>(num_sigma * params.tolerance)));
ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox<float>(NUM_SIGMA * MAX_SIGMA)));
ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox<float>(NUM_SIGMA * MAX_SIGMA)));
}
INSTANTIATE_TEST_SUITE_P(RngTests, RngTestF, ::testing::ValuesIn(inputsf));

typedef RngTest<double> RngTestD;
const std::vector<RngInputs<double>> inputsd = {
{0.0055, 32 * 1024, 1.0, 1.0, RNG_Normal, GenPhilox, 1234ULL},
{0.011, 8 * 1024, 1.0, 1.0, RNG_Normal, GenPhilox, 1234ULL},
{0.05, 32 * 1024, 1.0, 1.0, RNG_LogNormal, GenPhilox, 1234ULL},
{0.1, 8 * 1024, 1.0, 1.0, RNG_LogNormal, GenPhilox, 1234ULL},
{0.003, 32 * 1024, -1.0, 1.0, RNG_Uniform, GenPhilox, 1234ULL},
{0.005, 8 * 1024, -1.0, 1.0, RNG_Uniform, GenPhilox, 1234ULL},
{0.005, 32 * 1024, 1.0, 1.0, RNG_Gumbel, GenPhilox, 1234ULL},
{0.01, 8 * 1024, 1.0, 1.0, RNG_Gumbel, GenPhilox, 1234ULL},
{0.005, 32 * 1024, 1.0, 1.0, RNG_Logistic, GenPhilox, 67632ULL},
{0.01, 8 * 1024, 1.0, 1.0, RNG_Logistic, GenPhilox, 1234ULL},
{0.008, 32 * 1024, 1.0, 1.0, RNG_Exp, GenPhilox, 1234ULL},
{0.015, 8 * 1024, 1.0, 1.0, RNG_Exp, GenPhilox, 1234ULL},
{0.0125, 32 * 1024, 1.0, 1.0, RNG_Rayleigh, GenPhilox, 1234ULL},
{0.025, 8 * 1024, 1.0, 1.0, RNG_Rayleigh, GenPhilox, 1234ULL},
{0.02, 32 * 1024, 1.0, 1.0, RNG_Laplace, GenPhilox, 1234ULL},
{0.04, 8 * 1024, 1.0, 1.0, RNG_Laplace, GenPhilox, 1234ULL},

{0.0055, 32 * 1024, 1.0, 1.0, RNG_Normal, GenPC, 1234ULL},
{0.011, 8 * 1024, 1.0, 1.0, RNG_Normal, GenPC, 1234ULL},
{0.05, 32 * 1024, 1.0, 1.0, RNG_LogNormal, GenPC, 1234ULL},
{0.1, 8 * 1024, 1.0, 1.0, RNG_LogNormal, GenPC, 1234ULL},
{0.003, 32 * 1024, -1.0, 1.0, RNG_Uniform, GenPC, 1234ULL},
{0.005, 8 * 1024, -1.0, 1.0, RNG_Uniform, GenPC, 1234ULL},
{0.005, 32 * 1024, 1.0, 1.0, RNG_Gumbel, GenPC, 1234ULL},
{0.01, 8 * 1024, 1.0, 1.0, RNG_Gumbel, GenPC, 1234ULL},
{0.005, 32 * 1024, 1.0, 1.0, RNG_Logistic, GenPC, 1234ULL},
{0.01, 8 * 1024, 1.0, 1.0, RNG_Logistic, GenPC, 1234ULL},
{0.008, 32 * 1024, 1.0, 1.0, RNG_Exp, GenPC, 1234ULL},
{0.015, 8 * 1024, 1.0, 1.0, RNG_Exp, GenPC, 1234ULL},
{0.0125, 32 * 1024, 1.0, 1.0, RNG_Rayleigh, GenPC, 1234ULL},
{0.025, 8 * 1024, 1.0, 1.0, RNG_Rayleigh, GenPC, 1234ULL},
{0.02, 32 * 1024, 1.0, 1.0, RNG_Laplace, GenPC, 1234ULL},
{0.04, 8 * 1024, 1.0, 1.0, RNG_Laplace, GenPC, 1234ULL}};
// Test with Philox
{1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPhilox, 1234ULL},
{1024 * 1024, 1.2f, 0.1f, RNG_LogNormal, GenPhilox, 1234ULL},
{1024 * 1024, 1.2f, 5.5f, RNG_Uniform, GenPhilox, 1234ULL},
{1024 * 1024, 0.1f, 1.3f, RNG_Gumbel, GenPhilox, 1234ULL},
{1024 * 1024, 1.6f, 0.0f, RNG_Exp, GenPhilox, 1234ULL},
{1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPhilox, 1234ULL},
{1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPhilox, 1234ULL},
// Test with PCG
{1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPC, 1234ULL},
{1024 * 1024, 1.2f, 0.1f, RNG_LogNormal, GenPC, 1234ULL},
{1024 * 1024, 1.2f, 5.5f, RNG_Uniform, GenPC, 1234ULL},
{1024 * 1024, 0.1f, 1.3f, RNG_Gumbel, GenPC, 1234ULL},
{1024 * 1024, 1.6f, 0.0f, RNG_Exp, GenPC, 1234ULL},
{1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPC, 1234ULL},
{1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPC, 1234ULL}};

TEST_P(RngTestD, Result)
{
double meanvar[2];
getExpectedMeanVar(meanvar);
ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox<double>(num_sigma * params.tolerance)));
ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox<double>(num_sigma * params.tolerance)));
ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox<double>(NUM_SIGMA * MAX_SIGMA)));
ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox<double>(NUM_SIGMA * MAX_SIGMA)));
}
INSTANTIATE_TEST_SUITE_P(RngTests, RngTestD, ::testing::ValuesIn(inputsd));

Expand Down