Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] RNG test fixes and improvements #513

Merged
merged 23 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c9ed5a5
Updating the test description
vinaydes Feb 17, 2022
5db25ee
Correcting the seed type and updating parameter description
vinaydes Feb 17, 2022
873a2c6
Changing the test case parameters
vinaydes Feb 17, 2022
672e455
Fixing NaNs/Infs in the unit tests
vinaydes Feb 17, 2022
727dbc3
Formatting
vinaydes Feb 17, 2022
555339e
Moving the test description
vinaydes Feb 17, 2022
3676480
Correcting tolerances
vinaydes Feb 17, 2022
a677246
Consolidating tolerance parameters
vinaydes Feb 17, 2022
83dd813
Removing unused function
vinaydes Feb 17, 2022
dc14296
Moving include statement to correct place
vinaydes Feb 17, 2022
500864d
Replacing cuRAND uniform calls to change the range of generation from…
vinaydes Feb 24, 2022
422c37f
Added a temporary test
vinaydes Feb 24, 2022
8f53e4a
Removing check on res2 as it is not needed
vinaydes Feb 25, 2022
8c60042
Adding check for all Box-Muller calls
vinaydes Feb 25, 2022
98c2b2e
Fixing the log related checks
vinaydes Feb 25, 2022
758863f
Removing the debug test
vinaydes Feb 25, 2022
b3ebe3b
Formatting fixes
vinaydes Feb 25, 2022
be22797
Removing the possibility of passing zero to Box Muller transform
vinaydes Mar 2, 2022
1bbe79b
Changing the seed for a test case to avoid degenerate case
vinaydes Mar 2, 2022
ebb1a62
Improving format
vinaydes Mar 2, 2022
9e6853c
Correcting the Bernoulli dist generation
vinaydes Mar 8, 2022
7224a26
Adding comment explaining Laplace implementation
vinaydes Mar 8, 2022
e024451
Making integer to float/double conversion explicit
vinaydes Mar 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cpp/include/raft/random/detail/make_blobs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ __global__ void generate_data_kernel(DataT* out,
IdxT len = n_rows * n_cols;
for (IdxT idx = tid; idx < len; idx += stride) {
DataT val1, val2;
gen.next(val1);
do {
gen.next(val1);
} while (val1 == DataT(0.0));
gen.next(val2);
DataT mu1, sigma1, mu2, sigma2;
get_mu_sigma(mu1,
Expand Down
72 changes: 63 additions & 9 deletions cpp/include/raft/random/detail/rng_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,13 @@ DI void custom_next(
GenType& gen, OutType* val, NormalDistParams<OutType> params, LenType idx = 0, LenType stride = 0)
{
OutType res1, res2;
gen.next(res1);

do {
gen.next(res1);
} while (res1 == OutType(0.0));

gen.next(res2);

box_muller_transform<OutType>(res1, res2, params.sigma, params.mu);
*val = res1;
*(val + 1) = res2;
Expand All @@ -236,7 +241,11 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
IntType res1_int, res2_int;
gen.next(res1_int);

do {
gen.next(res1_int);
} while (res1_int == 0);

gen.next(res2_int);
double res1 = static_cast<double>(res1_int);
double res2 = static_cast<double>(res2_int);
Expand All @@ -255,7 +264,11 @@ DI void custom_next(GenType& gen,
LenType stride)
{
OutType res1, res2;
gen.next(res1);

do {
gen.next(res1);
} while (res1 == OutType(0.0));

gen.next(res2);
LenType col1 = idx % params.n_cols;
LenType col2 = (idx + stride) % params.n_cols;
Expand Down Expand Up @@ -294,7 +307,11 @@ DI void custom_next(
GenType& gen, OutType* val, GumbelDistParams<OutType> params, LenType idx = 0, LenType stride = 0)
{
OutType res = 0;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(0.0));

*val = params.mu - params.beta * raft::myLog(-raft::myLog(res));
}

Expand All @@ -306,7 +323,10 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res1 = 0, res2 = 0;
gen.next(res1);
do {
gen.next(res1);
} while (res1 == OutType(0.0));

gen.next(res2);
box_muller_transform<OutType>(res1, res2, params.sigma, params.mu);
*val = raft::myExp(res1);
Expand All @@ -321,7 +341,11 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(0.0));

constexpr OutType one = (OutType)1.0;
*val = params.mu - params.scale * raft::myLog(one / res - one);
}
Expand All @@ -348,6 +372,7 @@ DI void custom_next(GenType& gen,
{
OutType res;
gen.next(res);

constexpr OutType one = (OutType)1.0;
constexpr OutType two = (OutType)2.0;
*val = raft::mySqrt(-two * raft::myLog(one - res)) * params.sigma;
Expand All @@ -361,7 +386,11 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res, out;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(0.0));

constexpr OutType one = (OutType)1.0;
constexpr OutType two = (OutType)2.0;
constexpr OutType oneHalf = (OutType)0.5;
Expand Down Expand Up @@ -451,8 +480,33 @@ struct PhiloxGenerator {
return ret;
}

DI void next(float& ret) { ret = curand_uniform(&(this->philox_state)); }
DI void next(double& ret) { ret = curand_uniform_double(&(this->philox_state)); }
DI float next_float()
{
float ret;
uint32_t val = next_u32() >> 8;
ret = static_cast<float>(val) / (1U << 24);
vinaydes marked this conversation as resolved.
Show resolved Hide resolved
return ret;
}

DI double next_double()
{
double ret;
uint64_t val = next_u64() >> 11;
ret = static_cast<double>(val) / (1LU << 53);
return ret;
}

DI void next(float& ret)
{
// ret = curand_uniform(&(this->philox_state));
ret = next_float();
}

DI void next(double& ret)
{
// ret = curand_uniform_double(&(this->philox_state));
ret = next_double();
}

DI void next(uint32_t& ret) { ret = next_u32(); }
DI void next(uint64_t& ret) { ret = next_u64(); }
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/linalg/gemm_layout.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ const std::vector<GemmLayoutInputs<float>> inputsf = {
{50, 10, 60, false, true, true, 73012ULL},
{90, 90, 30, false, true, false, 538147ULL},
{30, 100, 10, false, false, true, 412352ULL},
{40, 80, 100, false, false, false, 297941ULL}};
{40, 80, 100, false, false, false, 2979410ULL}};

const std::vector<GemmLayoutInputs<double>> inputsd = {
{10, 70, 40, true, true, true, 535648ULL},
Expand Down
Loading