Skip to content

Commit

Permalink
Enhancement on uniform random sampling of indices near zero. (#2153)
Browse files Browse the repository at this point in the history
This is a partial fix for #1979. 

Specifically, given `N = out-deg(v)` and a random number `r ∈ [0,1]`, one must obtain the equivalent discrete 
`index ∈ {0,1,...,N-1}`. Previous implementation used an upper bound `ubound = N-1` and a linear interpolation. As the issue above mentioned that approach creates problems near the (lower) boundary. 

The fix uses a better bound, namely `ubound = N` and the discrete transformation: `index = floor(r >= 1.0 ? N-1 : r*N)`.

Attached Mathematica plots show the graphs for, say, `N = 13` and `N=17`.

![N=13_cropped](https://user-images.githubusercontent.com/37386037/159776745-13c72963-a426-46e2-975f-feedab6bbbb6.png)

![N=17_uniform_sampling](https://user-images.githubusercontent.com/37386037/159775015-203f4442-e2c7-4422-968e-e76807ec9639.png)

This fix is not high priority for release 22-04, and can be included in the 22-06 release. Also, not all of the concerns formulated in the issue above are addressed by this PR. For example a uniform random generator callable from device is not yet available, but there are plans to perhaps expose something like that in `raft`.

Authors:
  - Andrei Schaffer (https://github.com/aschaffer)

Approvers:
  - Chuck Hastings (https://github.com/ChuckHastings)

URL: #2153
  • Loading branch information
aschaffer authored Mar 24, 2022
1 parent 5c75b5b commit f1636a8
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
7 changes: 3 additions & 4 deletions cpp/src/sampling/random_walks.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,9 @@ struct rrandom_gen_t {
d_ptr_out_degs, // input2
d_ptr_out_degs, // also stencil
d_col_indx.begin(),
[] __device__(real_t rnd_vindx, edge_t crt_out_deg) {
real_t max_ub = static_cast<real_t>(crt_out_deg - 1);
auto interp_vindx = rnd_vindx * max_ub;
vertex_t v_indx = static_cast<vertex_t>(interp_vindx);
[] __device__(real_t rnd_val, edge_t crt_out_deg) {
vertex_t v_indx =
static_cast<vertex_t>(rnd_val >= 1.0 ? crt_out_deg - 1 : rnd_val * crt_out_deg);
return (v_indx >= crt_out_deg ? crt_out_deg - 1 : v_indx);
},
[] __device__(auto crt_out_deg) { return crt_out_deg > 0; });
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/sampling/rw_traversals.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,8 @@ struct uniform_selector_t {
auto crt_out_deg = ptr_d_cache_out_degs_[src_v];
if (crt_out_deg == 0) return thrust::nullopt; // src_v is a sink

real_t max_ub = static_cast<real_t>(crt_out_deg - 1);
auto interp_vindx = rnd_val * max_ub;
vertex_t v_indx = static_cast<vertex_t>(interp_vindx);

vertex_t v_indx =
static_cast<vertex_t>(rnd_val >= 1.0 ? crt_out_deg - 1 : rnd_val * crt_out_deg);
auto col_indx = v_indx >= crt_out_deg ? crt_out_deg - 1 : v_indx;
auto start_row = row_offsets_[src_v];

Expand Down

0 comments on commit f1636a8

Please sign in to comment.