diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 5577881ef7..4180025f77 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -136,7 +136,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME RANDOM_BENCH PATH bench/prims/random/make_blobs.cu bench/prims/random/permute.cu - bench/prims/random/rng.cu bench/prims/main.cpp + bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp ) ConfigureBench(NAME SPARSE_BENCH PATH bench/prims/sparse/convert_csr.cu bench/prims/main.cpp) diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index e6f26ba925..078f9e6198 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -16,34 +16,48 @@ #include +#include +#include #include #include #include #include #include +#include namespace raft::bench::matrix { template struct GatherParams { IdxT rows, cols, map_length; + bool host; }; template inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& { - os << p.rows << "#" << p.cols << "#" << p.map_length; + os << p.rows << "#" << p.cols << "#" << p.map_length << (p.host ? "#host" : "#device"); return os; } template struct Gather : public fixture { Gather(const GatherParams& p) - : params(p), matrix(this->handle), map(this->handle), out(this->handle), stencil(this->handle) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)), + matrix(this->handle), + map(this->handle), + out(this->handle), + stencil(this->handle), + matrix_h(this->handle) { + rmm::mr::set_current_device_resource(&pool_mr); } + ~Gather() { rmm::mr::set_current_device_resource(old_mr); } + void allocate_data(const ::benchmark::State& state) override { matrix = raft::make_device_matrix(handle, params.rows, params.cols); @@ -59,6 +73,11 @@ struct Gather : public fixture { if constexpr (Conditional) { raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1)); } + + if (params.host) { + matrix_h = raft::make_host_matrix(handle, params.rows, params.cols); + raft::copy(matrix_h.data_handle(), matrix.data_handle(), matrix.size(), stream); + } resource::sync_stream(handle, stream); } @@ -77,14 +96,22 @@ struct Gather : public fixture { raft::matrix::gather_if( handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); } else { - raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + if (params.host) { + raft::matrix::detail::gather( + handle, make_const_mdspan(matrix_h.view()), map_const_view, out.view()); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } } }); } private: GatherParams params; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; raft::device_matrix matrix, out; + raft::host_matrix matrix_h; raft::device_vector stencil; raft::device_vector map; }; // struct Gather @@ -100,4 +127,9 @@ RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); + +auto inputs_host = raft::util::itertools::product>( + {10000000}, {100}, {1000, 1000000, 10000000}, {true}); +RAFT_BENCH_REGISTER((Gather), "Host", inputs_host); + } // namespace raft::bench::matrix diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu new file mode 100644 index 0000000000..4c8ca2bf31 --- /dev/null +++ b/cpp/bench/prims/random/subsample.cu @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::bench::random { + +struct sample_inputs { + int n_samples; + int n_train; + int method; +}; // struct sample_inputs + +inline auto operator<<(std::ostream& os, const sample_inputs& p) -> std::ostream& +{ + os << p.n_samples << "#" << p.n_train << "#" << p.method; + return os; +} + +// Sample with replacement. We use this as a baseline. +template +auto bernoulli_subsample(raft::resources const& res, IdxT n_samples, IdxT n_subsamples, int seed) + -> raft::device_vector +{ + RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors"); + + auto indices = raft::make_device_vector(res, n_subsamples); + raft::random::RngState state(123456ULL); + raft::random::uniformInt( + res, state, indices.data_handle(), n_subsamples, IdxT(0), IdxT(n_samples)); + return indices; +} + +template +struct sample : public fixture { + sample(const sample_inputs& p) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB), + in(make_device_vector(res, p.n_samples)), + out(make_device_vector(res, p.n_train)) + { + rmm::mr::set_current_device_resource(&pool_mr); + raft::random::RngState r(123456ULL); + } + + ~sample() { rmm::mr::set_current_device_resource(old_mr); } + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + raft::random::RngState r(123456ULL); + loop_on_state(state, [this, &r]() { + if (params.method == 1) { + this->out = + bernoulli_subsample(this->res, this->params.n_samples, this->params.n_train, 137); + } else if (params.method == 2) { + this->out = raft::random::excess_subsample( + this->res, r, this->params.n_samples, this->params.n_train); + } + }); + } + + private: + float GiB = 1073741824.0f; + raft::device_resources res; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; + sample_inputs params; + raft::device_vector out, in; +}; // struct sample + +const std::vector input_vecs = {{100000000, 10000000, 1}, + {100000000, 50000000, 1}, + {100000000, 100000000, 1}, + {100000000, 10000000, 2}, + {100000000, 50000000, 2}, + {100000000, 100000000, 2}}; + +RAFT_BENCH_REGISTER(sample, "", input_vecs); + +} // namespace raft::bench::random diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 651fec81c3..05cc9204bf 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,9 +16,19 @@ #pragma once +#include +#include +#include +#include +#include #include +#include +#include +#include #include +#include + #include namespace raft { @@ -336,6 +346,83 @@ void gather_if(const InputIteratorT in, gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } +/** + * Helper function to gather a set of vectors from a (host) dataset. + */ +template +void gather_buff(host_matrix_view dataset, + host_vector_view indices, + MatIdxT offset, + pinned_matrix_view buff) +{ + raft::common::nvtx::range fun_scope("gather_host_buff"); + IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset); + +#pragma omp for + for (IdxT i = 0; i < batch_size; i++) { + IdxT in_idx = indices(offset + i); + for (IdxT k = 0; k < buff.extent(1); k++) { + buff(i, k) = dataset(in_idx, k); + } + } +} + +template +void gather(raft::resources const& res, + host_matrix_view dataset, + device_vector_view indices, + raft::device_matrix_view output) +{ + raft::common::nvtx::range fun_scope("gather"); + IdxT n_dim = output.extent(1); + IdxT n_train = output.extent(0); + auto indices_host = raft::make_host_vector(n_train); + raft::copy( + indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); + resource::sync_stream(res); + + const size_t buffer_size = 32768 * 1024; // bytes + const size_t max_batch_size = + std::min(round_up_safe(buffer_size / n_dim, 32), n_train); + RAFT_LOG_DEBUG("Gathering data with batch size %zu", max_batch_size); + + // Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync + // and gathering the data. + auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + auto out_tmp2 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + + // Usually a limited number of threads provide sufficient bandwidth for gathering data. + int n_threads = std::min(omp_get_max_threads(), 32); + + // The gather_buff function has a parallel for loop. We start the the omp parallel + // region here, to avoid repeated overhead within the device_offset loop. +#pragma omp parallel num_threads(n_threads) + { + auto view1 = out_tmp1.view(); + auto view2 = out_tmp2.view(); + gather_buff(dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1); + for (MatIdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) { + MatIdxT batch_size = std::min(max_batch_size, n_train - device_offset); + +#pragma omp master + raft::copy(output.data_handle() + device_offset * n_dim, + view1.data_handle(), + batch_size * n_dim, + resource::get_cuda_stream(res)); + // Start gathering the next batch on the host. + MatIdxT host_offset = device_offset + batch_size; + batch_size = std::min(max_batch_size, n_train - host_offset); + if (batch_size > 0) { + gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2); + } +#pragma omp master + resource::sync_stream(res); +#pragma omp barrier + std::swap(view1, view2); + } + } +} + } // namespace detail } // namespace matrix } // namespace raft diff --git a/cpp/include/raft/matrix/detail/sample_rows.cuh b/cpp/include/raft/matrix/detail/sample_rows.cuh new file mode 100644 index 0000000000..e28ad648da --- /dev/null +++ b/cpp/include/raft/matrix/detail/sample_rows.cuh @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::matrix::detail { + +/** Select rows randomly from input and copy to output. */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + const T* input, + IdxT n_rows_input, + raft::device_matrix_view output) +{ + IdxT n_dim = output.extent(1); + IdxT n_samples = output.extent(0); + + raft::device_vector train_indices = + raft::random::excess_subsample(res, random_state, n_rows_input, n_samples); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input)); + T* ptr = reinterpret_cast(attr.devicePointer); + if (ptr != nullptr) { + raft::matrix::gather(res, + raft::make_device_matrix_view(ptr, n_rows_input, n_dim), + raft::make_const_mdspan(train_indices.view()), + output); + } else { + auto dataset = raft::make_host_matrix_view(input, n_rows_input, n_dim); + raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output); + } +} +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/sample_rows.cuh b/cpp/include/raft/matrix/sample_rows.cuh new file mode 100644 index 0000000000..7925d344e4 --- /dev/null +++ b/cpp/include/raft/matrix/sample_rows.cuh @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::matrix { + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param output subsampled dataset + */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + raft::device_matrix_view output) +{ + RAFT_EXPECTS(dataset.extent(1) == output.extent(1), + "dataset dims must match, but received %ld vs %ld", + static_cast(dataset.extent(1)), + static_cast(output.extent(1))); + detail::sample_rows(res, random_state, dataset.data_handle(), dataset.extent(0), output); +} + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param n_samples number of rows in the returned matrix + * + * @return subsampled dataset + * */ +template +raft::device_matrix sample_rows( + raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + IdxT n_samples) +{ + auto output = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + sample_rows(res, random_state, dataset, output.view()); + return output; +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 57f4c8d33d..61a944e9b6 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,20 @@ #pragma once #include +#include +#include +#include +#include #include #include #include #include #include +#include + +#include + namespace raft { namespace random { namespace detail { @@ -278,6 +286,7 @@ std::enable_if_t> discrete(RngState& rng_state, len); } +/** Note the memory space requirements are O(4*len) */ template void sampleWithoutReplacement(RngState& rng_state, DataT* out, @@ -328,6 +337,133 @@ void affine_transform_params(RngState const& rng_state, IdxT n, IdxT& a, IdxT& b b = mt_rng() % n; } +/** @brief Sample without replacement from range 0..N-1. + * + * Elements are sampled uniformly. + * The algorithm will allocate a workspace of size O(4*n_samples) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1 + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) + -> raft::device_vector +{ + RAFT_EXPECTS(n_samples <= N, "Cannot have more training samples than dataset vectors"); + + // Number of samples we'll need to sample (with replacement), to expect 'k' + // unique samples from 'n' is given by the following equation: log(1 - k/n)/log(1 - 1/n) ref: + // https://stats.stackexchange.com/questions/296005/the-expected-number-of-unique-elements-drawn-with-replacement + IdxT n_excess_samples = + n_samples < N + ? std::ceil(raft::log(1 - double(n_samples) / double(N)) / (raft::log(1 - 1 / double(N)))) + : N; + + // There is a variance of n_excess_samples, we take 10% more elements. + n_excess_samples += std::max(0.1 * n_samples, 100); + + // n_excess_sampless will be larger than N around k = 0.64*N. When we reach N, then instead of + // doing rejection sampling, we simply shuffle the range [0..N-1] using N random numbers. + n_excess_samples = std::min(n_excess_samples, N); + auto rnd_idx = raft::make_device_vector(res, n_excess_samples); + + auto linear_idx = raft::make_device_vector(res, rnd_idx.size()); + raft::linalg::map_offset(res, linear_idx.view(), identity_op()); + + uniformInt(res, state, rnd_idx.data_handle(), rnd_idx.size(), IdxT(0), IdxT(N)); + + // Sort indices according to rnd keys + size_t workspace_size = 0; + auto stream = resource::get_cuda_stream(res); + cub::DeviceMergeSort::SortPairs(nullptr, + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + auto workspace = raft::make_device_vector(res, workspace_size); + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + + if (rnd_idx.size() == static_cast(N)) { + // We shuffled the linear_idx array by sorting it according to rnd_idx. + // We return the first n_samples elements. + if (n_samples == N) { return linear_idx; } + rnd_idx = raft::make_device_vector(res, n_samples); + raft::copy(rnd_idx.data_handle(), linear_idx.data_handle(), n_samples, stream); + return rnd_idx; + } + // Else we do a rejection sampling (or excess sampling): we generated more random indices than + // needed and reject the duplicates. + auto keys_out = raft::make_device_vector(res, rnd_idx.size()); + auto values_out = raft::make_device_vector(res, rnd_idx.size()); + rmm::device_scalar num_selected(stream); + size_t worksize2 = 0; + cub::DeviceSelect::UniqueByKey(nullptr, + worksize2, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + if (worksize2 > workspace.size()) { + workspace = raft::make_device_vector(res, worksize2); + workspace_size = workspace.size(); + } + + cub::DeviceSelect::UniqueByKey(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + IdxT selected = num_selected.value(stream); + + if (selected < n_samples) { + RAFT_LOG_DEBUG("Subsampling returned with less unique indices (%zu) than requested (%zu)", + (size_t)selected, + (size_t)n_samples); + + // We continue to select n_samples elements, this will now contains a few duplicates. + } + + // After duplicates are removed, we need to shuffle back to random order + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + values_out.data_handle(), + keys_out.data_handle(), + n_samples, + raft::less_op{}, + stream); + + values_out = raft::make_device_vector(res, n_samples); + raft::copy(values_out.data_handle(), keys_out.data_handle(), n_samples, stream); + return values_out; +} + }; // end namespace detail }; // end namespace random }; // end namespace raft diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 4e63669f98..7fd461980f 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -813,6 +813,32 @@ void sampleWithoutReplacement(raft::resources const& handle, rng_state, out, outIdx, in, wts, sampledLen, len, resource::get_cuda_stream(handle)); } +/** @brief Sample from range 0..N-1. + * + * Elements are sampled uniformly. The method aims to sample without replacement, + * but there is a small probability of a few having duplicate elements. + * + * The algorithm will allocate a workspace of size 4*n_samples*sizeof(IdxT) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1. + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) +{ + return detail::excess_subsample(res, state, N, n_samples); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 65d6e738a2..c27d7388af 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -266,6 +266,7 @@ if(BUILD_TESTS) test/matrix/matrix.cu test/matrix/norm.cu test/matrix/reverse.cu + test/matrix/sample_rows.cu test/matrix/slice.cu test/matrix/triangular.cu test/sparse/spectral_matrix.cu @@ -293,6 +294,7 @@ if(BUILD_TESTS) test/random/rng_int.cu test/random/rmat_rectangular_generator.cu test/random/sample_without_replacement.cu + test/random/excess_sampling.cu ) ConfigureTest( diff --git a/cpp/test/matrix/sample_rows.cu b/cpp/test/matrix/sample_rows.cu new file mode 100644 index 0000000000..e332a918fe --- /dev/null +++ b/cpp/test/matrix/sample_rows.cu @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace raft { +namespace matrix { + +struct inputs { + int N; + int dim; + int n_samples; + bool host; +}; + +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "#" << p.dim << "#" << p.n_samples << (p.host ? "#host" : "#device"); + return os; +} + +template +class SampleRowsTest : public ::testing::TestWithParam { + public: + SampleRowsTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL}, + in(make_device_matrix(res, params.N, params.dim)), + out(make_device_matrix(res, 0, 0)), + in_h(make_host_matrix(res, params.N, params.dim)), + out_h(make_host_matrix(res, params.n_samples, params.dim)) + { + raft::random::uniform(res, state, in.data_handle(), in.size(), T(-1.0), T(1.0)); + for (int64_t i = 0; i < params.N; i++) { + for (int64_t k = 0; k < params.dim; k++) + in_h(i, k) = i * 1000 + k; + } + raft::copy(in.data_handle(), in_h.data_handle(), in_h.size(), stream); + } + + void check() + { + if (params.host) { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in_h.view()), (int64_t)params.n_samples); + } else { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in.view()), (int64_t)params.n_samples); + } + + raft::copy(out_h.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + ASSERT_TRUE(out.extent(0) == params.n_samples); + ASSERT_TRUE(out.extent(1) == params.dim); + + std::unordered_set occurrence; + + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = out_h(i, 0) / 1000; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " params=" << params; + EXPECT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val << " params=" << params; + occurrence.insert(val); + for (int64_t k = 0; k < params.dim; k++) { + ASSERT_TRUE(raft::match(out_h(i, k), val * 1000 + k, raft::CompareApprox(1e-6))); + } + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + random::RngState state; + device_matrix in, out; + host_matrix in_h, out_h; +}; + +inline std::vector generate_inputs() +{ + std::vector input1 = + raft::util::itertools::product({10}, {1, 17, 96}, {1, 6, 9, 10}, {false}); + + std::vector input2 = + raft::util::itertools::product({137}, {1, 17, 128}, {1, 10, 100, 137}, {false}); + input1.insert(input1.end(), input2.begin(), input2.end()); + + input2 = raft::util::itertools::product( + {100000}, {1, 42}, {1, 137, 1000, 10000, 50000, 62000, 100000}, {false}); + + input1.insert(input1.end(), input2.begin(), input2.end()); + + int n = input1.size(); + // Add same tests for host data + for (int i = 0; i < n; i++) { + inputs x = input1[i]; + x.host = true; + input1.push_back(x); + } + return input1; +} + +const std::vector inputs1 = generate_inputs(); + +using SampleRowsTestInt64 = SampleRowsTest; +TEST_P(SampleRowsTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(SampleRowsTests, SampleRowsTestInt64, ::testing::ValuesIn(inputs1)); + +} // namespace matrix +} // namespace raft diff --git a/cpp/test/random/excess_sampling.cu b/cpp/test/random/excess_sampling.cu new file mode 100644 index 0000000000..e86436fb7d --- /dev/null +++ b/cpp/test/random/excess_sampling.cu @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft { +namespace random { + +using namespace raft::random; + +struct inputs { + int64_t N; + int64_t n_samples; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "/" << p.n_samples; + return os; +} + +template +class ExcessSamplingTest : public ::testing::TestWithParam { + public: + ExcessSamplingTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL} + { + } + + void check() + { + device_vector out = + raft::random::excess_subsample(res, state, params.N, params.n_samples); + ASSERT_TRUE(out.extent(0) == params.n_samples); + + auto h_out = make_host_vector(res, params.n_samples); + raft::copy(h_out.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + std::unordered_set occurrence; + int64_t sum = 0; + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = h_out(i); + sum += val; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " n_samples=" << params.n_samples; + ASSERT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val; + occurrence.insert(val); + } + float avg = sum / (float)params.n_samples; + if (params.n_samples >= 100 && params.N / params.n_samples < 100) { + ASSERT_TRUE(raft::match(avg, (params.N - 1) / 2.0f, raft::CompareApprox(0.2))) + << "non-uniform sample"; + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + RngState state; +}; + +const std::vector input1 = {{1, 0}, + {1, 1}, + {10, 0}, + {10, 1}, + {10, 2}, + {10, 10}, + {137, 42}, + {200, 0}, + {200, 1}, + {200, 100}, + {200, 130}, + {200, 200}, + {10000, 893}, + {10000000000, 1023}}; + +using ExcessSamplingTestInt64 = ExcessSamplingTest; +TEST_P(ExcessSamplingTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(ExcessSamplingTests, ExcessSamplingTestInt64, ::testing::ValuesIn(input1)); + +} // namespace random +} // namespace raft