Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating raft rng host public API and adding docs #636

Merged
merged 12 commits into from
May 5, 2022
2 changes: 1 addition & 1 deletion cpp/include/raft/linalg/detail/rsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void rsvdFixedRank(const raft::handle_t& handle,
// build random matrix
rmm::device_uvector<math_t> RN(n * l, stream);
raft::random::RngState state{484};
raft::random::normal(state, RN.data(), n * l, math_t(0.0), alpha, stream);
raft::random::normal(handle, state, RN.data(), n * l, math_t(0.0), alpha);

// multiply to get matrix of random samples Y
rmm::device_uvector<math_t> Y(m * l, stream);
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/random/detail/make_blobs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void generate_labels(IdxT* labels,
cudaStream_t stream)
{
IdxT a, b;
affine_transform_params(r, n_clusters, a, b);
raft::random::affine_transform_params(r, n_clusters, a, b);
auto op = [=] __device__(IdxT * ptr, IdxT idx) {
if (shuffle) { idx = IdxT((a * int64_t(idx)) + b); }
idx %= n_clusters;
Expand Down Expand Up @@ -230,7 +230,7 @@ void make_blobs_caller(DataT* out,
const DataT* _centers;
if (centers == nullptr) {
rand_centers.resize(n_clusters * n_cols, stream);
raft::random::uniform(
detail::uniform(
r, rand_centers.data(), n_clusters * n_cols, center_box_min, center_box_max, stream);
_centers = rand_centers.data();
} else {
Expand Down
Loading