Skip to content

Commit

Permalink
RNG test fixes and improvements (#513)
Browse files Browse the repository at this point in the history
- Fixes the failures of RNG tests and issues mentioned [here](#493).
- Reference standard deviation calculations: https://gist.github.com/vinaydes/cee04f50ff7e3365759603d39b7e079b
- Additionally fixes issue of Rng throwing NaNs and Infs
- Need to add parameter validation for each distribution


@MatthiasKohl @teju85 Please take a look.

Authors:
  - Vinay Deshpande (https://github.com/vinaydes)

Approvers:
  - Matt Joux (https://github.com/MatthiasKohl)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #513
  • Loading branch information
vinaydes authored Mar 8, 2022
1 parent a6f3caf commit 194c7ee
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 152 deletions.
4 changes: 3 additions & 1 deletion cpp/include/raft/random/detail/make_blobs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ __global__ void generate_data_kernel(DataT* out,
IdxT len = n_rows * n_cols;
for (IdxT idx = tid; idx < len; idx += stride) {
DataT val1, val2;
gen.next(val1);
do {
gen.next(val1);
} while (val1 == DataT(0.0));
gen.next(val2);
DataT mu1, sigma1, mu2, sigma2;
get_mu_sigma(mu1,
Expand Down
79 changes: 68 additions & 11 deletions cpp/include/raft/random/detail/rng_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,13 @@ DI void custom_next(
GenType& gen, OutType* val, NormalDistParams<OutType> params, LenType idx = 0, LenType stride = 0)
{
OutType res1, res2;
gen.next(res1);

do {
gen.next(res1);
} while (res1 == OutType(0.0));

gen.next(res2);

box_muller_transform<OutType>(res1, res2, params.sigma, params.mu);
*val = res1;
*(val + 1) = res2;
Expand All @@ -236,7 +241,11 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
IntType res1_int, res2_int;
gen.next(res1_int);

do {
gen.next(res1_int);
} while (res1_int == 0);

gen.next(res2_int);
double res1 = static_cast<double>(res1_int);
double res2 = static_cast<double>(res2_int);
Expand All @@ -255,7 +264,11 @@ DI void custom_next(GenType& gen,
LenType stride)
{
OutType res1, res2;
gen.next(res1);

do {
gen.next(res1);
} while (res1 == OutType(0.0));

gen.next(res2);
LenType col1 = idx % params.n_cols;
LenType col2 = (idx + stride) % params.n_cols;
Expand All @@ -274,7 +287,7 @@ DI void custom_next(
{
Type res = 0;
gen.next(res);
*val = res > params.prob;
*val = res < params.prob;
}

template <typename GenType, typename OutType, typename LenType>
Expand All @@ -286,15 +299,19 @@ DI void custom_next(GenType& gen,
{
OutType res = 0;
gen.next(res);
*val = res > params.prob ? -params.scale : params.scale;
*val = res < params.prob ? -params.scale : params.scale;
}

template <typename GenType, typename OutType, typename LenType>
DI void custom_next(
GenType& gen, OutType* val, GumbelDistParams<OutType> params, LenType idx = 0, LenType stride = 0)
{
OutType res = 0;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(0.0));

*val = params.mu - params.beta * raft::myLog(-raft::myLog(res));
}

Expand All @@ -306,7 +323,10 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res1 = 0, res2 = 0;
gen.next(res1);
do {
gen.next(res1);
} while (res1 == OutType(0.0));

gen.next(res2);
box_muller_transform<OutType>(res1, res2, params.sigma, params.mu);
*val = raft::myExp(res1);
Expand All @@ -321,7 +341,11 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(0.0));

constexpr OutType one = (OutType)1.0;
*val = params.mu - params.scale * raft::myLog(one / res - one);
}
Expand All @@ -348,6 +372,7 @@ DI void custom_next(GenType& gen,
{
OutType res;
gen.next(res);

constexpr OutType one = (OutType)1.0;
constexpr OutType two = (OutType)2.0;
*val = raft::mySqrt(-two * raft::myLog(one - res)) * params.sigma;
Expand All @@ -361,10 +386,17 @@ DI void custom_next(GenType& gen,
LenType stride = 0)
{
OutType res, out;
gen.next(res);

do {
gen.next(res);
} while (res == OutType(0.0));

constexpr OutType one = (OutType)1.0;
constexpr OutType two = (OutType)2.0;
constexpr OutType oneHalf = (OutType)0.5;

// The <= comparison here means, number of samples going in `if` branch are more by 1 than `else`
// branch. However it does not matter as for 0.5 both branches evaluate to same result.
if (res <= oneHalf) {
out = params.mu + params.scale * raft::myLog(two * res);
} else {
Expand Down Expand Up @@ -451,8 +483,33 @@ struct PhiloxGenerator {
return ret;
}

DI void next(float& ret) { ret = curand_uniform(&(this->philox_state)); }
DI void next(double& ret) { ret = curand_uniform_double(&(this->philox_state)); }
DI float next_float()
{
float ret;
uint32_t val = next_u32() >> 8;
ret = static_cast<float>(val) / float(uint32_t(1) << 24);
return ret;
}

DI double next_double()
{
double ret;
uint64_t val = next_u64() >> 11;
ret = static_cast<double>(val) / double(uint64_t(1) << 53);
return ret;
}

DI void next(float& ret)
{
// ret = curand_uniform(&(this->philox_state));
ret = next_float();
}

DI void next(double& ret)
{
// ret = curand_uniform_double(&(this->philox_state));
ret = next_double();
}

DI void next(uint32_t& ret) { ret = next_u32(); }
DI void next(uint64_t& ret) { ret = next_u64(); }
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/linalg/gemm_layout.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ const std::vector<GemmLayoutInputs<float>> inputsf = {
{50, 10, 60, false, true, true, 73012ULL},
{90, 90, 30, false, true, false, 538147ULL},
{30, 100, 10, false, false, true, 412352ULL},
{40, 80, 100, false, false, false, 297941ULL}};
{40, 80, 100, false, false, false, 2979410ULL}};

const std::vector<GemmLayoutInputs<double>> inputsd = {
{10, 70, 40, true, true, true, 535648ULL},
Expand Down
Loading

0 comments on commit 194c7ee

Please sign in to comment.