diff --git a/cpp/include/raft/random/rmat_rectangular_generator.cuh b/cpp/include/raft/random/rmat_rectangular_generator.cuh index 235c9e1cf8..01bd12cc55 100644 --- a/cpp/include/raft/random/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/rmat_rectangular_generator.cuh @@ -18,11 +18,218 @@ #include "detail/rmat_rectangular_generator.cuh" +#include #include #include +#include namespace raft::random { +/** + * @brief Type of the output vector(s) parameter for `rmat_rectangular_gen` + * (see below). + * + * @tparam IdxT Type of each node index; must be integral. + * + * Users can provide either `out` by itself, (`out_src` and `out_dst`) + * together, or all three (`out`, `out_src`, and `out_dst`). + * This class prevents users from doing anything other than that. + * It also checks compatibility of dimensions at run time. + * + * The following examples show how to create an output parameter. + * + * @code + * rmat_rectangular_gen_output output1(out); + * rmat_rectangular_gen_output output2(out_src, out_dst); + * rmat_rectangular_gen_output output3(out, out_src, out_dst); + * @endcode + * + * @{ + */ +template +class rmat_rectangular_gen_output { + public: + using out_view_type = + raft::device_mdspan, raft::row_major>; + using out_src_view_type = raft::device_vector_view; + using out_dst_view_type = raft::device_vector_view; + + private: + class output_pair { + public: + output_pair(const out_src_view_type& src, const out_dst_view_type& dst) : src_(src), dst_(dst) + { + RAFT_EXPECTS(src.extent(0) == dst.extent(0), + "rmat_rectangular_gen: " + "out_src.extent(0) = %d != out_dst.extent(0) = %d", + static_cast(src.extent(0)), + static_cast(dst.extent(0))); + } + + out_src_view_type out_src_view() const { return src_; } + + out_dst_view_type out_dst_view() const { return dst_; } + + IdxT number_of_edges() const { return src_.extent(0); } + + private: + out_src_view_type src_; + out_dst_view_type dst_; + }; + + class output_triple { + public: + output_triple(const out_view_type& out, + const out_src_view_type& src, + const out_dst_view_type& dst) + : out_(out), pair_(src, dst) + { + RAFT_EXPECTS(out.extent(0) == IdxT(2) * dst.extent(0), + "rmat_rectangular_gen: " + "out.extent(0) = %d != 2 * out_dst.extent(0) = %d", + static_cast(out.extent(0)), + static_cast(IdxT(2) * dst.extent(0))); + } + + out_view_type out_view() const { return out_; } + + out_src_view_type out_src_view() const { return pair_.out_src_view(); } + + out_dst_view_type out_dst_view() const { return pair_.out_dst_view(); } + + IdxT number_of_edges() const { return pair_.number_of_edges(); } + + private: + out_view_type out_; + output_pair pair_; + }; + + public: + /** + * @brief Constructor taking no vectors, + * that effectively makes all the vectors length zero. + */ + rmat_rectangular_gen_output() = default; + + /** + * @brief Constructor taking a single vector, that packs the source + * node ids and destination node ids in array-of-structs fashion. + * + * @param[out] out Generated edgelist [on device]. In each row, the + * first element is the source node id, and the second element is + * the destination node id. + */ + rmat_rectangular_gen_output(const out_view_type& out) : data_(out) {} + + /** + * @brief Constructor taking two vectors, that store the source node + * ids and the destination node ids separately, in + * struct-of-arrays fashion. + * + * @param[out] out_src Source node id's [on device] [len = n_edges]. + * + * @param[out] out_dst Destination node id's [on device] [len = n_edges]. + */ + rmat_rectangular_gen_output(const out_src_view_type& src, const out_dst_view_type& dst) + : data_(output_pair(src, dst)) + { + } + + /** + * @brief Constructor taking all three vectors. + * + * @param[out] out Generated edgelist [on device]. In each row, the + * first element is the source node id, and the second element is + * the destination node id. + * + * @param[out] out_src Source node id's [on device] [len = n_edges]. + * + * @param[out] out_dst Destination node id's [on device] [len = n_edges]. + */ + rmat_rectangular_gen_output(const out_view_type& out, + const out_src_view_type& src, + const out_dst_view_type& dst) + : data_(output_triple(out, src, dst)) + { + } + + /** + * @brief Whether this object was created with a constructor + * taking more than zero arguments. + */ + bool has_value() const { return not std::holds_alternative(data_); } + + /** + * @brief Vector for the output single edgelist; the argument given + * to the one-argument constructor, or the first argument of the + * three-argument constructor; `std::nullopt` if not provided. + */ + std::optional out_view() const + { + if (std::holds_alternative(data_)) { + return std::get(data_); + } else if (std::holds_alternative(data_)) { + return std::get(data_).out_view(); + } else { + return std::nullopt; + } + } + + /** + * @brief Vector for the output source edgelist; the first argument + * given to the two-argument constructor, or the second argument + * of the three-argument constructor; `std::nullopt` if not provided. + */ + std::optional out_src_view() const + { + if (std::holds_alternative(data_)) { + return std::get(data_).out_src_view(); + } else if (std::holds_alternative(data_)) { + return std::get(data_).out_src_view(); + } else { + return std::nullopt; + } + } + + /** + * @brief Vector for the output destination edgelist; the second + * argument given to the two-argument constructor, or the third + * argument of the three-argument constructor; + * `std::nullopt` if not provided. + */ + std::optional out_dst_view() const + { + if (std::holds_alternative(data_)) { + return std::get(data_).out_dst_view(); + } else if (std::holds_alternative(data_)) { + return std::get(data_).out_dst_view(); + } else { + return std::nullopt; + } + } + + /** + * @brief Number of edges in the graph; zero if no output vector + * was provided to the constructor. + */ + IdxT number_of_edges() const + { + if (std::holds_alternative(data_)) { + return std::get(data_).extent(0); + } else if (std::holds_alternative(data_)) { + return std::get(data_).number_of_edges(); + } else if (std::holds_alternative(data_)) { + return std::get(data_).number_of_edges(); + } else { + return IdxT(0); + } + } + + private: + // Defaults to std::nullopt. + std::variant data_; +}; + /** * @brief Generate RMAT for a rectangular adjacency matrix (useful when * graphs to be generated are bipartite) @@ -36,22 +243,14 @@ namespace raft::random { * a larger graph. For that case, just create this object with the * initial seed once and after every call continue to pass the same * object for the successive calls. - * @param[out] out generated edgelist [on device] [dim = n_edges x 2]. In each row - * the first element is the source node id, and the second element - * is the destination node id. If you don't need this output - * then pass a `nullptr` in its place. - * @param[out] out_src list of source node id's [on device] [len = n_edges]. If you - * don't need this output then pass a `nullptr` in its place. - * @param[out] out_dst list of destination node id's [on device] [len = n_edges]. If - * you don't need this output then pass a `nullptr` in its place. * @param[in] theta distribution of each quadrant at each level of resolution. * Since these are probabilities, each of the 2x2 matrices for * each level of the RMAT must sum to one. [on device] * [dim = max(r_scale, c_scale) x 2 x 2]. Of course, it is assumed * that each of the group of 2 x 2 numbers all sum up to 1. + * @param[out] output generated edgelist [on device] * @param[in] r_scale 2^r_scale represents the number of source nodes * @param[in] c_scale 2^c_scale represents the number of destination nodes - * @param[in] n_edges number of edges to generate * * We call the `r_scale != c_scale` case the "rectangular adjacency matrix" case (IOW generating * bipartite graphs). In this case, at `depth >= r_scale`, the distribution is assumed to be: @@ -64,23 +263,45 @@ namespace raft::random { * @note This also only generates directed graphs. If undirected graphs are needed, then a * separate post-processing step is expected to be done by the caller. - * - * @{ */ template void rmat_rectangular_gen(const raft::handle_t& handle, raft::random::RngState& r, raft::device_vector_view theta, - raft::device_vector_view out, - raft::device_vector_view out_src, - raft::device_vector_view out_dst, + rmat_rectangular_gen_output output, IdxT r_scale, - IdxT c_scale, - IdxT n_edges) + IdxT c_scale) { - detail::rmat_rectangular_gen_caller(out.data_handle(), - out_src.data_handle(), - out_dst.data_handle(), + static_assert(std::is_integral_v, + "rmat_rectangular_gen: " + "Template parameter IdxT must be an integral type"); + if (not output.has_value()) { + return; // nowhere to write output, so nothing to do + } + + const IdxT expected_theta_len = IdxT(4) * (r_scale >= c_scale ? r_scale : c_scale); + RAFT_EXPECTS(theta.extent(0) == expected_theta_len, + "rmat_rectangular_gen: " + "theta.extent(0) = %d != 2 * 2 * max(r_scale = %d, c_scale = %d) = %d", + static_cast(theta.extent(0)), + static_cast(r_scale), + static_cast(c_scale), + static_cast(expected_theta_len)); + + auto out = output.out_view(); + auto out_src = output.out_src_view(); + auto out_dst = output.out_dst_view(); + const bool out_has_value = out.has_value(); + const bool out_src_has_value = out_src.has_value(); + const bool out_dst_has_value = out_dst.has_value(); + IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr; + IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr; + IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; + const IdxT n_edges = output.number_of_edges(); + + detail::rmat_rectangular_gen_caller(out_ptr, + out_src_ptr, + out_dst_ptr, theta.data_handle(), r_scale, c_scale, @@ -89,6 +310,45 @@ void rmat_rectangular_gen(const raft::handle_t& handle, r); } +/** + * @brief Overload of `rmat_rectangular_gen` that assumes the same + * a, b, c, d probability distributions across all the scales. + * + * `a`, `b, and `c` effectively replace the above overload's + * `theta` parameter. + */ +template +void rmat_rectangular_gen(const raft::handle_t& handle, + raft::random::RngState& r, + rmat_rectangular_gen_output output, + ProbT a, + ProbT b, + ProbT c, + IdxT r_scale, + IdxT c_scale) +{ + static_assert(std::is_integral_v, + "rmat_rectangular_gen: " + "Template parameter IdxT must be an integral type"); + if (not output.has_value()) { + return; // nowhere to write output, so nothing to do + } + + auto out = output.out_view(); + auto out_src = output.out_src_view(); + auto out_dst = output.out_dst_view(); + const bool out_has_value = out.has_value(); + const bool out_src_has_value = out_src.has_value(); + const bool out_dst_has_value = out_dst.has_value(); + IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr; + IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr; + IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; + const IdxT n_edges = output.number_of_edges(); + + detail::rmat_rectangular_gen_caller( + out_ptr, out_src_ptr, out_dst_ptr, a, b, c, r_scale, c_scale, n_edges, handle.get_stream(), r); +} + /** * @brief Legacy overload of `rmat_rectangular_gen` * taking raw arrays instead of mdspan. @@ -134,36 +394,6 @@ void rmat_rectangular_gen(IdxT* out, out, out_src, out_dst, theta, r_scale, c_scale, n_edges, stream, r); } -/** - * This is the same as the previous method but assumes the same a, b, c, d probability - * distributions across all the scales - */ -template -void rmat_rectangular_gen(const raft::handle_t& handle, - raft::random::RngState& r, - raft::device_vector_view out, - raft::device_vector_view out_src, - raft::device_vector_view out_dst, - ProbT a, - ProbT b, - ProbT c, - IdxT r_scale, - IdxT c_scale, - IdxT n_edges) -{ - detail::rmat_rectangular_gen_caller(out.data_handle(), - out_src.data_handle(), - out_dst.data_handle(), - a, - b, - c, - r_scale, - c_scale, - n_edges, - handle.get_stream(), - r); -} - /** * @brief Legacy overload of `rmat_rectangular_gen` * taking raw arrays instead of mdspan. diff --git a/cpp/test/random/rmat_rectangular_generator.cu b/cpp/test/random/rmat_rectangular_generator.cu index a58a20a05e..7f25e19c9b 100644 --- a/cpp/test/random/rmat_rectangular_generator.cu +++ b/cpp/test/random/rmat_rectangular_generator.cu @@ -287,33 +287,24 @@ class RmatGenMdspanTest : public ::testing::TestWithParam { void SetUp() override { using index_type = size_t; - raft::device_vector_view out_view(out.data(), out.size()); - raft::device_vector_view out_src_view(out_src.data(), out_src.size()); - raft::device_vector_view out_dst_view(out_dst.data(), out_dst.size()); + + using out_view_type = typename rmat_rectangular_gen_output::out_view_type; + out_view_type out_view(out.data(), out.size()); + + using out_src_view_type = typename rmat_rectangular_gen_output::out_src_view_type; + out_src_view_type out_src_view(out_src.data(), out_src.size()); + + using out_dst_view_type = typename rmat_rectangular_gen_output::out_dst_view_type; + out_dst_view_type out_dst_view(out_dst.data(), out_dst.size()); + + rmat_rectangular_gen_output output(out_view, out_src_view, out_dst_view); if (params.theta_array) { raft::device_vector_view theta_view(theta.data(), theta.size()); - rmat_rectangular_gen(handle, - state, - theta_view, - out_view, - out_src_view, - out_dst_view, - params.r_scale, - params.c_scale, - params.n_edges); + rmat_rectangular_gen(handle, state, theta_view, output, params.r_scale, params.c_scale); } else { - rmat_rectangular_gen(handle, - state, - out_view, - out_src_view, - out_dst_view, - h_theta[0], - h_theta[1], - h_theta[2], - params.r_scale, - params.c_scale, - params.n_edges); + rmat_rectangular_gen( + handle, state, output, h_theta[0], h_theta[1], h_theta[2], params.r_scale, params.c_scale); } RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); }