diff --git a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh index ddb7214a1a..5ce7e909ee 100644 --- a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh @@ -16,6 +16,9 @@ #pragma once +#include "rmat_rectangular_generator_types.cuh" + +#include #include #include #include @@ -182,6 +185,111 @@ void rmat_rectangular_gen_caller(IdxT* out, r.advance(n_edges, max_scale); } +/** + * @brief Implementation of `raft::random::rmat_rectangular_gen_impl`. + * + * @tparam IdxT type of each node index + * @tparam ProbT data type used for probability distributions (either fp32 or fp64) + * @param[in] handle RAFT handle, containing the CUDA stream on which to schedule work + * @param[in] r underlying state of the random generator. Especially useful when + * one wants to call this API for multiple times in order to generate + * a larger graph. For that case, just create this object with the + * initial seed once and after every call continue to pass the same + * object for the successive calls. + * @param[out] output Encapsulation of one, two, or three output vectors. + * @param[in] theta distribution of each quadrant at each level of resolution. + * Since these are probabilities, each of the 2x2 matrices for + * each level of the RMAT must sum to one. [on device] + * [dim = max(r_scale, c_scale) x 2 x 2]. Of course, it is assumed + * that each of the group of 2 x 2 numbers all sum up to 1. + * @param[in] r_scale 2^r_scale represents the number of source nodes + * @param[in] c_scale 2^c_scale represents the number of destination nodes + */ +template +void rmat_rectangular_gen_impl(const raft::handle_t& handle, + raft::random::RngState& r, + raft::device_vector_view theta, + raft::random::detail::rmat_rectangular_gen_output output, + IdxT r_scale, + IdxT c_scale) +{ + static_assert(std::is_integral_v, + "rmat_rectangular_gen: " + "Template parameter IdxT must be an integral type"); + if (output.empty()) { + return; // nothing to do; not an error + } + + const IdxT expected_theta_len = IdxT(4) * (r_scale >= c_scale ? r_scale : c_scale); + RAFT_EXPECTS(theta.extent(0) == expected_theta_len, + "rmat_rectangular_gen: " + "theta.extent(0) = %zu != 2 * 2 * max(r_scale = %zu, c_scale = %zu) = %zu", + static_cast(theta.extent(0)), + static_cast(r_scale), + static_cast(c_scale), + static_cast(expected_theta_len)); + + auto out = output.out_view(); + auto out_src = output.out_src_view(); + auto out_dst = output.out_dst_view(); + const bool out_has_value = out.has_value(); + const bool out_src_has_value = out_src.has_value(); + const bool out_dst_has_value = out_dst.has_value(); + IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr; + IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr; + IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; + const IdxT n_edges = output.number_of_edges(); + + rmat_rectangular_gen_caller(out_ptr, + out_src_ptr, + out_dst_ptr, + theta.data_handle(), + r_scale, + c_scale, + n_edges, + handle.get_stream(), + r); +} + +/** + * @brief Overload of `rmat_rectangular_gen` that assumes the same + * a, b, c, d probability distributions across all the scales. + * + * `a`, `b, and `c` effectively replace the above overload's + * `theta` parameter. + */ +template +void rmat_rectangular_gen_impl(const raft::handle_t& handle, + raft::random::RngState& r, + raft::random::detail::rmat_rectangular_gen_output output, + ProbT a, + ProbT b, + ProbT c, + IdxT r_scale, + IdxT c_scale) +{ + static_assert(std::is_integral_v, + "rmat_rectangular_gen: " + "Template parameter IdxT must be an integral type"); + if (output.empty()) { + return; // nothing to do; not an error + } + + auto out = output.out_view(); + auto out_src = output.out_src_view(); + auto out_dst = output.out_dst_view(); + const bool out_has_value = out.has_value(); + const bool out_src_has_value = out_src.has_value(); + const bool out_dst_has_value = out_dst.has_value(); + IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr; + IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr; + IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; + const IdxT n_edges = output.number_of_edges(); + + detail::rmat_rectangular_gen_caller( + out_ptr, out_src_ptr, out_dst_ptr, a, b, c, r_scale, c_scale, n_edges, handle.get_stream(), r); +} + } // end namespace detail } // end namespace random } // end namespace raft diff --git a/cpp/include/raft/random/detail/rmat_rectangular_generator_types.cuh b/cpp/include/raft/random/detail/rmat_rectangular_generator_types.cuh new file mode 100644 index 0000000000..daf3392f3d --- /dev/null +++ b/cpp/include/raft/random/detail/rmat_rectangular_generator_types.cuh @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { +namespace random { +namespace detail { + +/** + * @brief Implementation detail for checking output vector parameter(s) + * of `raft::random::rmat_rectangular_gen`. + * + * `raft::random::rmat_rectangular_gen` lets users specify + * output vector(s) in three different ways. + * + * 1. One vector: `out`, an "array-of-structs" representation + * of the edge list. + * + * 2. Two vectors: `out_src` and `out_dst`, together forming + * a "struct of arrays" representation of the edge list. + * + * 3. Three vectors: `out`, `out_src`, and `out_dst`. + * `out` is as in (1), + * and `out_src` and `out_dst` are as in (2). + * + * This class prevents users from doing anything other than that, + * and makes it easier for the three cases to share a common implementation. + * It also prevents duplication of run-time vector length checking + * (`out` must have twice the number of elements as `out_src` and `out_dst`, + * and `out_src` and `out_dst` must have the same length). + * + * @tparam IdxT Type of each node index; must be integral. + * + * The following examples show how to create an output parameter. + * + * @code + * rmat_rectangular_gen_output output1(out); + * rmat_rectangular_gen_output output2(out_src, out_dst); + * rmat_rectangular_gen_output output3(out, out_src, out_dst); + * @endcode + */ +template +class rmat_rectangular_gen_output { + public: + using out_view_type = + raft::device_mdspan, raft::row_major>; + using out_src_view_type = raft::device_vector_view; + using out_dst_view_type = raft::device_vector_view; + + private: + class output_pair { + public: + output_pair(const out_src_view_type& src, const out_dst_view_type& dst) : src_(src), dst_(dst) + { + RAFT_EXPECTS(src.extent(0) == dst.extent(0), + "rmat_rectangular_gen: " + "out_src.extent(0) = %zu != out_dst.extent(0) = %zu", + static_cast(src.extent(0)), + static_cast(dst.extent(0))); + } + + out_src_view_type out_src_view() const { return src_; } + + out_dst_view_type out_dst_view() const { return dst_; } + + IdxT number_of_edges() const { return src_.extent(0); } + + bool empty() const { return src_.extent(0) == 0 && dst_.extent(0) == 0; } + + private: + out_src_view_type src_; + out_dst_view_type dst_; + }; + + class output_triple { + public: + output_triple(const out_view_type& out, + const out_src_view_type& src, + const out_dst_view_type& dst) + : out_(out), pair_(src, dst) + { + RAFT_EXPECTS(out.extent(0) == IdxT(2) * dst.extent(0), + "rmat_rectangular_gen: " + "out.extent(0) = %zu != 2 * out_dst.extent(0) = %zu", + static_cast(out.extent(0)), + static_cast(IdxT(2) * dst.extent(0))); + } + + out_view_type out_view() const { return out_; } + + out_src_view_type out_src_view() const { return pair_.out_src_view(); } + + out_dst_view_type out_dst_view() const { return pair_.out_dst_view(); } + + IdxT number_of_edges() const { return pair_.number_of_edges(); } + + bool empty() const { return out_.extent(0) == 0 && pair_.empty(); } + + private: + out_view_type out_; + output_pair pair_; + }; + + public: + /** + * @brief You're not allowed to construct this with no vectors. + */ + rmat_rectangular_gen_output() = delete; + + /** + * @brief Constructor taking a single vector, that packs the source + * node ids and destination node ids in array-of-structs fashion. + * + * @param[out] out Generated edgelist [on device]. In each row, the + * first element is the source node id, and the second element is + * the destination node id. + */ + rmat_rectangular_gen_output(const out_view_type& out) : data_(out) {} + + /** + * @brief Constructor taking two vectors, that store the source node + * ids and the destination node ids separately, in + * struct-of-arrays fashion. + * + * @param[out] out_src Source node id's [on device] [len = n_edges]. + * + * @param[out] out_dst Destination node id's [on device] [len = n_edges]. + */ + rmat_rectangular_gen_output(const out_src_view_type& src, const out_dst_view_type& dst) + : data_(output_pair(src, dst)) + { + } + + /** + * @brief Constructor taking all three vectors. + * + * @param[out] out Generated edgelist [on device]. In each row, the + * first element is the source node id, and the second element is + * the destination node id. + * + * @param[out] out_src Source node id's [on device] [len = n_edges]. + * + * @param[out] out_dst Destination node id's [on device] [len = n_edges]. + */ + rmat_rectangular_gen_output(const out_view_type& out, + const out_src_view_type& src, + const out_dst_view_type& dst) + : data_(output_triple(out, src, dst)) + { + } + + /** + * @brief Whether the vector(s) are all length zero. + */ + bool empty() const + { + if (std::holds_alternative(data_)) { + return std::get(data_).extent(0) == 0; + } else if (std::holds_alternative(data_)) { + return std::get(data_).empty(); + } else { // std::holds_alternative(data_) + return std::get(data_).empty(); + } + } + + /** + * @brief Vector for the output single edgelist; the argument given + * to the one-argument constructor, or the first argument of the + * three-argument constructor; `std::nullopt` if not provided. + */ + std::optional out_view() const + { + if (std::holds_alternative(data_)) { + return std::get(data_); + } else if (std::holds_alternative(data_)) { + return std::get(data_).out_view(); + } else { // if (std::holds_alternative<>(output_pair)) + return std::nullopt; + } + } + + /** + * @brief Vector for the output source edgelist; the first argument + * given to the two-argument constructor, or the second argument + * of the three-argument constructor; `std::nullopt` if not provided. + */ + std::optional out_src_view() const + { + if (std::holds_alternative(data_)) { + return std::get(data_).out_src_view(); + } else if (std::holds_alternative(data_)) { + return std::get(data_).out_src_view(); + } else { // if (std::holds_alternative(data_)) + return std::nullopt; + } + } + + /** + * @brief Vector for the output destination edgelist; the second + * argument given to the two-argument constructor, or the third + * argument of the three-argument constructor; + * `std::nullopt` if not provided. + */ + std::optional out_dst_view() const + { + if (std::holds_alternative(data_)) { + return std::get(data_).out_dst_view(); + } else if (std::holds_alternative(data_)) { + return std::get(data_).out_dst_view(); + } else { // if (std::holds_alternative(data_)) + return std::nullopt; + } + } + + /** + * @brief Number of edges in the graph; zero if no output vector + * was provided to the constructor. + */ + IdxT number_of_edges() const + { + if (std::holds_alternative(data_)) { + return std::get(data_).extent(0); + } else if (std::holds_alternative(data_)) { + return std::get(data_).number_of_edges(); + } else { // if (std::holds_alternative(data_)) + return std::get(data_).number_of_edges(); + } + } + + private: + std::variant data_; +}; + +} // end namespace detail +} // end namespace random +} // end namespace raft diff --git a/cpp/include/raft/random/rmat_rectangular_generator.cuh b/cpp/include/raft/random/rmat_rectangular_generator.cuh index aad1cf0c88..cedcca1711 100644 --- a/cpp/include/raft/random/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/rmat_rectangular_generator.cuh @@ -21,48 +21,226 @@ namespace raft::random { /** - * @brief Generate RMAT for a rectangular shaped adjacency matrices (useful when - * graphs to be generated are bipartite) + * @brief Generate a bipartite RMAT graph for a rectangular adjacency matrix. * - * @tparam IdxT node indices type + * This is the most general of several overloads of `rmat_rectangular_gen` + * in this file, and thus has the most detailed documentation. + * + * @tparam IdxT Type of each node index + * @tparam ProbT Data type used for probability distributions (either fp32 or fp64) + * + * @param[in] handle RAFT handle, containing the CUDA stream on which to schedule work + * @param[in] r underlying state of the random generator. Especially useful when + * one wants to call this API for multiple times in order to generate + * a larger graph. For that case, just create this object with the + * initial seed once and after every call continue to pass the same + * object for the successive calls. + * @param[out] out Generated edgelist [on device], packed in array-of-structs fashion. + * In each row, the first element is the source node id, + * and the second element is the destination node id. + * @param[out] out_src Source node id's [on device]. + * @param[out] out_dst Destination node id's [on device]. `out_src` and `out_dst` + * together form the struct-of-arrays representation of the same + * output data as `out`. + * @param[in] theta distribution of each quadrant at each level of resolution. + * Since these are probabilities, each of the 2x2 matrices for + * each level of the RMAT must sum to one. [on device] + * [dim = max(r_scale, c_scale) x 2 x 2]. Of course, it is assumed + * that each of the group of 2 x 2 numbers all sum up to 1. + * @param[in] r_scale 2^r_scale represents the number of source nodes + * @param[in] c_scale 2^c_scale represents the number of destination nodes + * + * @pre `out.extent(0) == 2 * `out_src.extent(0)` is `true` + * @pre `out_src.extent(0) == out_dst.extent(0)` is `true` + * + * We call the `r_scale != c_scale` case the "rectangular adjacency matrix" case + * (in other words, generating bipartite graphs). In this case, at `depth >= r_scale`, + * the distribution is assumed to be: + * + * `[theta[4 * depth] + theta[4 * depth + 2], theta[4 * depth + 1] + theta[4 * depth + 3]; 0, 0]`. + * + * Then for `depth >= c_scale`, the distribution is assumed to be: + * + * `[theta[4 * depth] + theta[4 * depth + 1], 0; theta[4 * depth + 2] + theta[4 * depth + 3], 0]`. + * + * @note This can generate duplicate edges and self-loops. It is the responsibility of the + * caller to clean them up accordingly. + * + * @note This also only generates directed graphs. If undirected graphs are needed, then a + * separate post-processing step is expected to be done by the caller. + * + * @{ + */ +template +void rmat_rectangular_gen( + const raft::handle_t& handle, + raft::random::RngState& r, + raft::device_vector_view theta, + raft::device_mdspan, raft::row_major> out, + raft::device_vector_view out_src, + raft::device_vector_view out_dst, + IdxT r_scale, + IdxT c_scale) +{ + detail::rmat_rectangular_gen_output output(out, out_src, out_dst); + detail::rmat_rectangular_gen_impl(handle, r, theta, output, r_scale, c_scale); +} + +/** + * @brief Overload of `rmat_rectangular_gen` that only generates + * the struct-of-arrays (two vectors) output representation. + * + * This overload only generates the struct-of-arrays (two vectors) + * output representation: output vector `out_src` of source node id's, + * and output vector `out_dst` of destination node id's. + * + * @pre `out_src.extent(0) == out_dst.extent(0)` is `true` + */ +template +void rmat_rectangular_gen(const raft::handle_t& handle, + raft::random::RngState& r, + raft::device_vector_view theta, + raft::device_vector_view out_src, + raft::device_vector_view out_dst, + IdxT r_scale, + IdxT c_scale) +{ + detail::rmat_rectangular_gen_output output(out_src, out_dst); + detail::rmat_rectangular_gen_impl(handle, r, theta, output, r_scale, c_scale); +} + +/** + * @brief Overload of `rmat_rectangular_gen` that only generates + * the array-of-structs (one vector) output representation. + * + * This overload only generates the array-of-structs (one vector) + * output representation: a single output vector `out`, + * where in each row, the first element is the source node id, + * and the second element is the destination node id. + */ +template +void rmat_rectangular_gen( + const raft::handle_t& handle, + raft::random::RngState& r, + raft::device_vector_view theta, + raft::device_mdspan, raft::row_major> out, + IdxT r_scale, + IdxT c_scale) +{ + detail::rmat_rectangular_gen_output output(out); + detail::rmat_rectangular_gen_impl(handle, r, theta, output, r_scale, c_scale); +} + +/** + * @brief Overload of `rmat_rectangular_gen` that assumes the same + * a, b, c, d probability distributions across all the scales, + * and takes all three output vectors + * (`out` with the array-of-structs output representation, + * and `out_src` and `out_dst` with the struct-of-arrays + * output representation). + * + * `a`, `b, and `c` effectively replace the above overloads' + * `theta` parameter. + * + * @pre `out.extent(0) == 2 * `out_src.extent(0)` is `true` + * @pre `out_src.extent(0) == out_dst.extent(0)` is `true` + */ +template +void rmat_rectangular_gen( + const raft::handle_t& handle, + raft::random::RngState& r, + raft::device_mdspan, raft::row_major> out, + raft::device_vector_view out_src, + raft::device_vector_view out_dst, + ProbT a, + ProbT b, + ProbT c, + IdxT r_scale, + IdxT c_scale) +{ + detail::rmat_rectangular_gen_output output(out, out_src, out_dst); + detail::rmat_rectangular_gen_impl(handle, r, output, a, b, c, r_scale, c_scale); +} + +/** + * @brief Overload of `rmat_rectangular_gen` that assumes the same + * a, b, c, d probability distributions across all the scales, + * and takes only two output vectors + * (the struct-of-arrays output representation). + * + * `a`, `b, and `c` effectively replace the above overloads' + * `theta` parameter. + * + * @pre `out_src.extent(0) == out_dst.extent(0)` is `true` + */ +template +void rmat_rectangular_gen(const raft::handle_t& handle, + raft::random::RngState& r, + raft::device_vector_view out_src, + raft::device_vector_view out_dst, + ProbT a, + ProbT b, + ProbT c, + IdxT r_scale, + IdxT c_scale) +{ + detail::rmat_rectangular_gen_output output(out_src, out_dst); + detail::rmat_rectangular_gen_impl(handle, r, output, a, b, c, r_scale, c_scale); +} + +/** + * @brief Overload of `rmat_rectangular_gen` that assumes the same + * a, b, c, d probability distributions across all the scales, + * and takes only one output vector + * (the array-of-structs output representation). + * + * `a`, `b, and `c` effectively replace the above overloads' + * `theta` parameter. + */ +template +void rmat_rectangular_gen( + const raft::handle_t& handle, + raft::random::RngState& r, + raft::device_mdspan, raft::row_major> out, + ProbT a, + ProbT b, + ProbT c, + IdxT r_scale, + IdxT c_scale) +{ + detail::rmat_rectangular_gen_output output(out); + detail::rmat_rectangular_gen_impl(handle, r, output, a, b, c, r_scale, c_scale); +} + +/** + * @brief Legacy overload of `rmat_rectangular_gen` + * taking raw arrays instead of mdspan. + * + * @tparam IdxT type of each node index * @tparam ProbT data type used for probability distributions (either fp32 or fp64) * - * @param[out] out generated edgelist [on device] [dim = n_edges x 2]. On each row - * the first element corresponds to the source node id while the - * second, the destination node id. If you don't need this output + * @param[out] out generated edgelist [on device] [dim = n_edges x 2]. In each row + * the first element is the source node id, and the second element + * is the destination node id. If you don't need this output * then pass a `nullptr` in its place. * @param[out] out_src list of source node id's [on device] [len = n_edges]. If you * don't need this output then pass a `nullptr` in its place. * @param[out] out_dst list of destination node id's [on device] [len = n_edges]. If * you don't need this output then pass a `nullptr` in its place. * @param[in] theta distribution of each quadrant at each level of resolution. - * Since these are probabilities, each of the 2x2 matrix for + * Since these are probabilities, each of the 2x2 matrices for * each level of the RMAT must sum to one. [on device] * [dim = max(r_scale, c_scale) x 2 x 2]. Of course, it is assumed * that each of the group of 2 x 2 numbers all sum up to 1. * @param[in] r_scale 2^r_scale represents the number of source nodes * @param[in] c_scale 2^c_scale represents the number of destination nodes * @param[in] n_edges number of edges to generate - * @param[in] stream cuda stream to schedule the work on + * @param[in] stream cuda stream on which to schedule the work * @param[in] r underlying state of the random generator. Especially useful when * one wants to call this API for multiple times in order to generate * a larger graph. For that case, just create this object with the * initial seed once and after every call continue to pass the same * object for the successive calls. - * - * When `r_scale != c_scale` it is referred to as rectangular adjacency matrix case (IOW generating - * bipartite graphs). In this case, at `depth >= r_scale`, the distribution is assumed to be: - * `[theta[4 * depth] + theta[4 * depth + 2], theta[4 * depth + 1] + theta[4 * depth + 3]; 0, 0]`. - * Then for the `depth >= c_scale`, the distribution is assumed to be: - * `[theta[4 * depth] + theta[4 * depth + 1], 0; theta[4 * depth + 2] + theta[4 * depth + 3], 0]`. - * - * @note This can generate duplicate edges and self-loops. It is the responsibility of the - * caller to clean them up accordingly. - - * @note This also only generates directed graphs. If undirected graphs are needed, then a - * separate post-processing step is expected to be done by the caller. - * - * @{ */ template void rmat_rectangular_gen(IdxT* out, @@ -80,8 +258,10 @@ void rmat_rectangular_gen(IdxT* out, } /** - * This is the same as the previous method but assumes the same a, b, c, d probability - * distributions across all the scales + * @brief Legacy overload of `rmat_rectangular_gen` + * taking raw arrays instead of mdspan. + * This overload assumes the same a, b, c, d probability distributions + * across all the scales. */ template void rmat_rectangular_gen(IdxT* out, @@ -99,6 +279,7 @@ void rmat_rectangular_gen(IdxT* out, detail::rmat_rectangular_gen_caller( out, out_src, out_dst, a, b, c, r_scale, c_scale, n_edges, stream, r); } + /** @} */ } // end namespace raft::random diff --git a/cpp/include/raft/random/rng_state.hpp b/cpp/include/raft/random/rng_state.hpp index 44372902b1..ec15ef286f 100644 --- a/cpp/include/raft/random/rng_state.hpp +++ b/cpp/include/raft/random/rng_state.hpp @@ -19,6 +19,8 @@ #pragma once +#include + namespace raft { namespace random { diff --git a/cpp/test/nvtx.cpp b/cpp/test/nvtx.cpp index 81f692a215..635fe55012 100644 --- a/cpp/test/nvtx.cpp +++ b/cpp/test/nvtx.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ #ifdef NVTX_ENABLED #include -#include +#include /** * tests for the functionality of generating next color based on string * entered in the NVTX Range marker wrappers diff --git a/cpp/test/random/rmat_rectangular_generator.cu b/cpp/test/random/rmat_rectangular_generator.cu index 194f89dd65..0baaaf28cf 100644 --- a/cpp/test/random/rmat_rectangular_generator.cu +++ b/cpp/test/random/rmat_rectangular_generator.cu @@ -253,6 +253,111 @@ class RmatGenTest : public ::testing::TestWithParam { size_t max_scale; }; +class RmatGenMdspanTest : public ::testing::TestWithParam { + public: + RmatGenMdspanTest() + : handle{}, + stream{handle.get_stream()}, + params{::testing::TestWithParam::GetParam()}, + out{params.n_edges * 2, stream}, + out_src{params.n_edges, stream}, + out_dst{params.n_edges, stream}, + theta{0, stream}, + h_theta{}, + state{params.seed, GeneratorType::GenPC}, + max_scale{std::max(params.r_scale, params.c_scale)} + { + theta.resize(4 * max_scale, stream); + uniform(state, theta.data(), theta.size(), 0.0f, 1.0f, stream); + normalize(theta.data(), + theta.data(), + max_scale, + params.r_scale, + params.c_scale, + params.r_scale != params.c_scale, + params.theta_array, + stream); + h_theta.resize(theta.size()); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::update_host(h_theta.data(), theta.data(), theta.size(), stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + protected: + void SetUp() override + { + using index_type = size_t; + + using out_view_type = raft::device_mdspan, + raft::row_major>; + out_view_type out_view(out.data(), out.size()); + + using out_src_view_type = raft::device_vector_view; + out_src_view_type out_src_view(out_src.data(), out_src.size()); + + using out_dst_view_type = raft::device_vector_view; + out_dst_view_type out_dst_view(out_dst.data(), out_dst.size()); + + if (params.theta_array) { + raft::device_vector_view theta_view(theta.data(), theta.size()); + rmat_rectangular_gen(handle, + state, + theta_view, + out_view, + out_src_view, + out_dst_view, + params.r_scale, + params.c_scale); + } else { + rmat_rectangular_gen(handle, + state, + out_view, + out_src_view, + out_dst_view, + h_theta[0], + h_theta[1], + h_theta[2], + params.r_scale, + params.c_scale); + } + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void validate() + { + rmm::device_uvector hist{theta.size(), stream}; + RAFT_CUDA_TRY(cudaMemsetAsync(hist.data(), 0, hist.size() * sizeof(int), stream)); + compute_hist<<(out.size() / 2, 256), 256, 0, stream>>>( + hist.data(), out.data(), out.size(), max_scale, params.r_scale, params.c_scale); + RAFT_CUDA_TRY(cudaGetLastError()); + rmm::device_uvector computed_theta{theta.size(), stream}; + normalize(computed_theta.data(), + hist.data(), + max_scale, + params.r_scale, + params.c_scale, + false, + true, + stream); + RAFT_CUDA_TRY(cudaGetLastError()); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT_TRUE(devArrMatchHost( + h_theta.data(), computed_theta.data(), theta.size(), CompareApprox(params.eps))); + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + RmatInputs params; + rmm::device_uvector out, out_src, out_dst; + rmm::device_uvector theta; + std::vector h_theta; + RngState state; + size_t max_scale; +}; + static const float TOLERANCE = 0.01f; const std::vector inputs = { @@ -287,5 +392,8 @@ const std::vector inputs = { TEST_P(RmatGenTest, Result) { validate(); } INSTANTIATE_TEST_SUITE_P(RmatGenTests, RmatGenTest, ::testing::ValuesIn(inputs)); +TEST_P(RmatGenMdspanTest, Result) { validate(); } +INSTANTIATE_TEST_SUITE_P(RmatGenMdspanTests, RmatGenMdspanTest, ::testing::ValuesIn(inputs)); + } // namespace random } // namespace raft