diff --git a/cpp/src/sampling/random_walks.cuh b/cpp/src/sampling/random_walks.cuh index 0f124f62b1d..ac615bf7b14 100644 --- a/cpp/src/sampling/random_walks.cuh +++ b/cpp/src/sampling/random_walks.cuh @@ -125,10 +125,9 @@ struct rrandom_gen_t { d_ptr_out_degs, // input2 d_ptr_out_degs, // also stencil d_col_indx.begin(), - [] __device__(real_t rnd_vindx, edge_t crt_out_deg) { - real_t max_ub = static_cast(crt_out_deg - 1); - auto interp_vindx = rnd_vindx * max_ub; - vertex_t v_indx = static_cast(interp_vindx); + [] __device__(real_t rnd_val, edge_t crt_out_deg) { + vertex_t v_indx = + static_cast(rnd_val >= 1.0 ? crt_out_deg - 1 : rnd_val * crt_out_deg); return (v_indx >= crt_out_deg ? crt_out_deg - 1 : v_indx); }, [] __device__(auto crt_out_deg) { return crt_out_deg > 0; }); diff --git a/cpp/src/sampling/rw_traversals.hpp b/cpp/src/sampling/rw_traversals.hpp index fc2d45981aa..5b08257880e 100644 --- a/cpp/src/sampling/rw_traversals.hpp +++ b/cpp/src/sampling/rw_traversals.hpp @@ -153,10 +153,8 @@ struct uniform_selector_t { auto crt_out_deg = ptr_d_cache_out_degs_[src_v]; if (crt_out_deg == 0) return thrust::nullopt; // src_v is a sink - real_t max_ub = static_cast(crt_out_deg - 1); - auto interp_vindx = rnd_val * max_ub; - vertex_t v_indx = static_cast(interp_vindx); - + vertex_t v_indx = + static_cast(rnd_val >= 1.0 ? crt_out_deg - 1 : rnd_val * crt_out_deg); auto col_indx = v_indx >= crt_out_deg ? crt_out_deg - 1 : v_indx; auto start_row = row_offsets_[src_v];