Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Rng: expose host-rng-state in host-only API #609

Merged
merged 18 commits into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
17ab145
rng: expose host-rng-state in host-only API
MatthiasKohl Apr 2, 2022
114198c
raft rng: moving back from rng.hpp to rng.cuh
MatthiasKohl Apr 5, 2022
595a5ba
Merge branch 'branch-22.06' into fea-rng-hpp-api
MatthiasKohl Apr 14, 2022
432e786
deprecated rng main header as well as impl
MatthiasKohl Apr 14, 2022
767a834
rng: better separation of functionality into state, launch and device…
MatthiasKohl Apr 14, 2022
d4f52f1
rng: applied new API throughout raft code-base
MatthiasKohl Apr 14, 2022
9f43868
rng: fixed compilation and clang-format
MatthiasKohl Apr 14, 2022
f7f8f62
Merge branch 'branch-22.06' into fea-rng-hpp-api
cjnolet Apr 20, 2022
e91159e
rng API: deprecated rng.hpp only. deprecated attributes for old API, …
MatthiasKohl Apr 22, 2022
98f1a31
rng API: removed rng_launch.cuh, using rng.cuh again
MatthiasKohl Apr 22, 2022
19c319a
Merge branch 'fea-rng-hpp-api' of github.com:MatthiasKohl/raft into f…
MatthiasKohl Apr 22, 2022
3550c90
rng API: clang-format fixes after updates
MatthiasKohl Apr 22, 2022
3c3a45c
rng API: use [[deprecated]] to make doxygen (and clang) happy
MatthiasKohl Apr 22, 2022
da55f83
RNG API: improved includes, copyright
MatthiasKohl Apr 25, 2022
4ffbc90
RNG API: removed handle from sample API
MatthiasKohl Apr 25, 2022
9342f51
RNG API: removed compiler warning to stay non-breaking for downstream…
MatthiasKohl Apr 25, 2022
474dca9
RNG API: improved copyright again
MatthiasKohl Apr 25, 2022
2303377
RNG API: rng.hpp deprecation warning as in spatial/knn.hpp
MatthiasKohl Apr 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/bench/random/permute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <common/benchmark.hpp>
#include <raft/cudart_utils.h>
#include <raft/random/permute.hpp>
#include <raft/random/rng.hpp>
#include <raft/random/rng_launch.cuh>

#include <rmm/device_uvector.hpp>

Expand All @@ -36,13 +36,13 @@ struct permute : public fixture {
out(p.rows * p.cols, stream),
in(p.rows * p.cols, stream)
{
raft::random::Rng r(123456ULL);
r.uniform(in.data(), p.rows, T(-1.0), T(1.0), stream);
raft::random::RngState r(123456ULL);
uniform(r, in.data(), p.rows, T(-1.0), T(1.0), stream);
}

void run_benchmark(::benchmark::State& state) override
{
raft::random::Rng r(123456ULL);
raft::random::RngState r(123456ULL);
loop_on_state(state, [this, &r]() {
raft::random::permute(
perms.data(), out.data(), in.data(), params.cols, params.rows, params.rowMajor, stream);
Expand Down
22 changes: 11 additions & 11 deletions cpp/bench/random/rng.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <common/benchmark.hpp>
#include <raft/cudart_utils.h>
#include <raft/random/rng.hpp>
#include <raft/random/rng_launch.cuh>

#include <rmm/device_uvector.hpp>

Expand Down Expand Up @@ -48,26 +48,26 @@ struct rng : public fixture {

void run_benchmark(::benchmark::State& state) override
{
raft::random::Rng r(123456ULL, params.gtype);
raft::random::RngState r(123456ULL, params.gtype);
loop_on_state(state, [this, &r]() {
switch (params.type) {
case RNG_Normal: r.normal(ptr.data(), params.len, params.start, params.end, stream); break;
case RNG_Normal: normal(r, ptr.data(), params.len, params.start, params.end, stream); break;
case RNG_LogNormal:
r.lognormal(ptr.data(), params.len, params.start, params.end, stream);
lognormal(r, ptr.data(), params.len, params.start, params.end, stream);
break;
case RNG_Uniform:
r.uniform(ptr.data(), params.len, params.start, params.end, stream);
uniform(r, ptr.data(), params.len, params.start, params.end, stream);
break;
case RNG_Gumbel: r.gumbel(ptr.data(), params.len, params.start, params.end, stream); break;
case RNG_Gumbel: gumbel(r, ptr.data(), params.len, params.start, params.end, stream); break;
case RNG_Logistic:
r.logistic(ptr.data(), params.len, params.start, params.end, stream);
logistic(r, ptr.data(), params.len, params.start, params.end, stream);
break;
case RNG_Exp: r.exponential(ptr.data(), params.len, params.start, stream); break;
case RNG_Rayleigh: r.rayleigh(ptr.data(), params.len, params.start, stream); break;
case RNG_Exp: exponential(r, ptr.data(), params.len, params.start, stream); break;
case RNG_Rayleigh: rayleigh(r, ptr.data(), params.len, params.start, stream); break;
case RNG_Laplace:
r.laplace(ptr.data(), params.len, params.start, params.end, stream);
laplace(r, ptr.data(), params.len, params.start, params.end, stream);
break;
case RNG_Fill: r.fill(ptr.data(), params.len, params.start, stream); break;
case RNG_Fill: fill(r, ptr.data(), params.len, params.start, stream); break;
};
});
}
Expand Down
8 changes: 4 additions & 4 deletions cpp/bench/spatial/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <raft/distance/fused_l2_nn.hpp>
#include <raft/handle.hpp>
#include <raft/linalg/norm.hpp>
#include <raft/random/rng.hpp>
#include <raft/random/rng_launch.cuh>

#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.hpp>
Expand All @@ -44,10 +44,10 @@ struct fused_l2_nn : public fixture {
workspace(p.m, stream)
{
raft::handle_t handle{stream};
raft::random::Rng r(123456ULL);
raft::random::RngState r(123456ULL);

r.uniform(x.data(), p.m * p.k, T(-1.0), T(1.0), stream);
r.uniform(y.data(), p.n * p.k, T(-1.0), T(1.0), stream);
uniform(r, x.data(), p.m * p.k, T(-1.0), T(1.0), stream);
uniform(r, y.data(), p.n * p.k, T(-1.0), T(1.0), stream);
raft::linalg::rowNorm(xn.data(), x.data(), p.k, p.m, raft::linalg::L2Norm, true, stream);
raft::linalg::rowNorm(yn.data(), y.data(), p.k, p.n, raft::linalg::L2Norm, true, stream);
raft::distance::initialize<T, cub::KeyValuePair<int, T>, int>(
Expand Down
6 changes: 3 additions & 3 deletions cpp/bench/spatial/selection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <raft/spatial/knn/specializations.hpp>
#endif

#include <raft/random/rng.hpp>
#include <raft/random/rng_launch.cuh>
#include <raft/sparse/detail/utils.h>

#include <rmm/mr/device/per_device_resource.hpp>
Expand All @@ -46,8 +46,8 @@ struct selection : public fixture {
out_ids_(p.n_inputs * p.k, stream)
{
raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream);
raft::random::Rng(42).uniform(
in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0), stream);
raft::random::RngState state{42};
raft::random::uniform(state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0), stream);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
1 change: 0 additions & 1 deletion cpp/include/raft/linalg/detail/lstsq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/math.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/random/rng.cuh>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/linalg/detail/rsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/math.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/rng_launch.cuh>

#include <algorithm>

Expand Down Expand Up @@ -92,8 +92,8 @@ void rsvdFixedRank(const raft::handle_t& handle,

// build random matrix
rmm::device_uvector<math_t> RN(n * l, stream);
raft::random::Rng rng(484);
rng.normal(RN.data(), n * l, math_t(0.0), alpha, stream);
raft::random::RngState state{484};
raft::random::normal(state, RN.data(), n * l, math_t(0.0), alpha, stream);

// multiply to get matrix of random samples Y
rmm::device_uvector<math_t> Y(m * l, stream);
Expand Down
47 changes: 26 additions & 21 deletions cpp/include/raft/random/detail/make_blobs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
#include <raft/cuda_utils.cuh>
#include <raft/cudart_utils.h>
#include <raft/linalg/unary_op.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/rng_device.cuh>
#include <raft/random/rng_launch.cuh>
#include <rmm/device_uvector.hpp>
#include <vector>

Expand All @@ -35,11 +36,11 @@ void generate_labels(IdxT* labels,
IdxT n_rows,
IdxT n_clusters,
bool shuffle,
raft::random::Rng& r,
raft::random::RngState& r,
cudaStream_t stream)
{
IdxT a, b;
r.affine_transform_params(n_clusters, a, b);
affine_transform_params(r, n_clusters, a, b);
auto op = [=] __device__(IdxT * ptr, IdxT idx) {
if (shuffle) { idx = IdxT((a * int64_t(idx)) + b); }
idx %= n_clusters;
Expand Down Expand Up @@ -89,20 +90,20 @@ DI void get_mu_sigma(DataT& mu,
mu = centers[center_id];
}

template <typename DataT, typename IdxT>
__global__ void generate_data_kernel(DataT* out,
template <typename DataT, typename IdxT, typename GenType>
__global__ void generate_data_kernel(raft::random::DeviceState<GenType> rng_state,
DataT* out,
const IdxT* labels,
IdxT n_rows,
IdxT n_cols,
IdxT n_clusters,
bool row_major,
const DataT* centers,
const DataT* cluster_std,
const DataT cluster_std_scalar,
raft::random::RngState rng_state)
const DataT cluster_std_scalar)
{
uint64_t tid = (blockIdx.x * blockDim.x) + threadIdx.x;
raft::random::PhiloxGenerator gen(rng_state, tid);
GenType gen(rng_state, tid);
const IdxT stride = gridDim.x * blockDim.x;
IdxT len = n_rows * n_cols;
for (IdxT idx = tid; idx < len; idx += stride) {
Expand Down Expand Up @@ -157,16 +158,19 @@ void generate_data(DataT* out,
{
IdxT items = n_rows * n_cols;
IdxT nBlocks = (items + 127) / 128;
generate_data_kernel<<<nBlocks, 128, 0, stream>>>(out,
labels,
n_rows,
n_cols,
n_clusters,
row_major,
centers,
cluster_std,
cluster_std_scalar,
rng_state);
// parentheses needed here for kernel, otherwise macro interprets the arguments
// of triple chevron notation as macro arguments
RAFT_CALL_RNG_FUNC(rng_state,
(generate_data_kernel<<<nBlocks, 128, 0, stream>>>),
out,
labels,
n_rows,
n_cols,
n_clusters,
row_major,
centers,
cluster_std,
cluster_std_scalar);
}

/**
Expand Down Expand Up @@ -220,13 +224,14 @@ void make_blobs_caller(DataT* out,
uint64_t seed,
raft::random::GeneratorType type)
{
raft::random::Rng r(seed, type);
raft::random::RngState r(seed, type);
// use the right centers buffer for data generation
rmm::device_uvector<DataT> rand_centers(0, stream);
const DataT* _centers;
if (centers == nullptr) {
rand_centers.resize(n_clusters * n_cols, stream);
r.uniform(rand_centers.data(), n_clusters * n_cols, center_box_min, center_box_max, stream);
raft::random::uniform(
r, rand_centers.data(), n_clusters * n_cols, center_box_min, center_box_max, stream);
_centers = rand_centers.data();
} else {
_centers = centers;
Expand All @@ -242,7 +247,7 @@ void make_blobs_caller(DataT* out,
_centers,
cluster_std,
cluster_std_scalar,
r.state);
r);
}

} // end namespace detail
Expand Down
18 changes: 9 additions & 9 deletions cpp/include/raft/random/detail/make_regression.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/random/permute.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/rng_launch.cuh>
#include <rmm/device_uvector.hpp>

namespace raft::random {
Expand All @@ -58,7 +58,7 @@ static void _make_low_rank_matrix(const raft::handle_t& handle,
IdxT n_cols,
IdxT effective_rank,
DataT tail_strength,
raft::random::Rng& r,
raft::random::RngState& r,
cudaStream_t stream)
{
cusolverDnHandle_t cusolver_handle = handle.get_cusolver_dn_handle();
Expand All @@ -69,8 +69,8 @@ static void _make_low_rank_matrix(const raft::handle_t& handle,
// Generate random (ortho normal) vectors with QR decomposition
rmm::device_uvector<DataT> rd_mat_0(n_rows * n, stream);
rmm::device_uvector<DataT> rd_mat_1(n_cols * n, stream);
r.normal(rd_mat_0.data(), n_rows * n, (DataT)0.0, (DataT)1.0, stream);
r.normal(rd_mat_1.data(), n_cols * n, (DataT)0.0, (DataT)1.0, stream);
normal(r, rd_mat_0.data(), n_rows * n, (DataT)0.0, (DataT)1.0, stream);
normal(r, rd_mat_1.data(), n_cols * n, (DataT)0.0, (DataT)1.0, stream);
rmm::device_uvector<DataT> q0(n_rows * n, stream);
rmm::device_uvector<DataT> q1(n_cols * n, stream);
raft::linalg::qrGetQ(handle, rd_mat_0.data(), q0.data(), n_rows, n, stream);
Expand Down Expand Up @@ -166,11 +166,11 @@ void make_regression_caller(const raft::handle_t& handle,
cublasHandle_t cublas_handle = handle.get_cublas_handle();

cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);
raft::random::Rng r(seed, type);
raft::random::RngState r(seed, type);

if (effective_rank < 0) {
// Randomly generate a well conditioned input set
r.normal(out, n_rows * n_cols, (DataT)0.0, (DataT)1.0, stream);
normal(r, out, n_rows * n_cols, (DataT)0.0, (DataT)1.0, stream);
} else {
// Randomly generate a low rank, fat tail input set
_make_low_rank_matrix(handle, out, n_rows, n_cols, effective_rank, tail_strength, r, stream);
Expand Down Expand Up @@ -207,7 +207,7 @@ void make_regression_caller(const raft::handle_t& handle,
}

// Generate a ground truth model with only n_informative features
r.uniform(_coef, n_informative * n_targets, (DataT)1.0, (DataT)100.0, stream);
uniform(r, _coef, n_informative * n_targets, (DataT)1.0, (DataT)100.0, stream);
if (coef && n_informative != n_cols) {
RAFT_CUDA_TRY(cudaMemsetAsync(_coef + n_informative * n_targets,
0,
Expand Down Expand Up @@ -247,7 +247,7 @@ void make_regression_caller(const raft::handle_t& handle,
if (noise != 0.0) {
// Add white noise
white_noise.resize(n_rows * n_targets, stream);
r.normal(white_noise.data(), n_rows * n_targets, (DataT)0.0, noise, stream);
normal(r, white_noise.data(), n_rows * n_targets, (DataT)0.0, noise, stream);
raft::linalg::add(_values, _values, white_noise.data(), n_rows * n_targets, stream);
}

Expand Down Expand Up @@ -281,4 +281,4 @@ void make_regression_caller(const raft::handle_t& handle,
}

} // namespace detail
} // namespace raft::random
} // namespace raft::random
Loading