diff --git a/cpp/include/raft/random/rmat_rectangular_generator.cuh b/cpp/include/raft/random/rmat_rectangular_generator.cuh index c348ae573e..01bd12cc55 100644 --- a/cpp/include/raft/random/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/rmat_rectangular_generator.cuh @@ -19,9 +19,9 @@ #include "detail/rmat_rectangular_generator.cuh" #include -#include #include #include +#include namespace raft::random { @@ -48,77 +48,63 @@ namespace raft::random { */ template class rmat_rectangular_gen_output { -public: + public: using out_view_type = raft::device_mdspan, raft::row_major>; using out_src_view_type = raft::device_vector_view; using out_dst_view_type = raft::device_vector_view; -private: + private: class output_pair { - public: - output_pair(const out_src_view_type& src, - const out_dst_view_type& dst) - : src_(src), dst_(dst) + public: + output_pair(const out_src_view_type& src, const out_dst_view_type& dst) : src_(src), dst_(dst) { - RAFT_EXPECTS(src.extent(0) == dst.extent(0), "rmat_rectangular_gen: " + RAFT_EXPECTS(src.extent(0) == dst.extent(0), + "rmat_rectangular_gen: " "out_src.extent(0) = %d != out_dst.extent(0) = %d", static_cast(src.extent(0)), static_cast(dst.extent(0))); } - out_src_view_type out_src_view() const { - return src_; - } + out_src_view_type out_src_view() const { return src_; } - out_dst_view_type out_dst_view() const { - return dst_; - } + out_dst_view_type out_dst_view() const { return dst_; } - IdxT number_of_edges() const { - return src_.extent(0); - } + IdxT number_of_edges() const { return src_.extent(0); } - private: + private: out_src_view_type src_; out_dst_view_type dst_; }; class output_triple { - public: + public: output_triple(const out_view_type& out, const out_src_view_type& src, const out_dst_view_type& dst) : out_(out), pair_(src, dst) { - RAFT_EXPECTS(out.extent(0) == IdxT(2) * dst.extent(0), "rmat_rectangular_gen: " + RAFT_EXPECTS(out.extent(0) == IdxT(2) * dst.extent(0), + "rmat_rectangular_gen: " "out.extent(0) = %d != 2 * out_dst.extent(0) = %d", static_cast(out.extent(0)), static_cast(IdxT(2) * dst.extent(0))); } - out_view_type out_view() const { - return out_; - } + out_view_type out_view() const { return out_; } - out_src_view_type out_src_view() const { - return pair_.out_src_view(); - } + out_src_view_type out_src_view() const { return pair_.out_src_view(); } - out_dst_view_type out_dst_view() const { - return pair_.out_dst_view(); - } + out_dst_view_type out_dst_view() const { return pair_.out_dst_view(); } - IdxT number_of_edges() const { - return pair_.number_of_edges(); - } + IdxT number_of_edges() const { return pair_.number_of_edges(); } - private: + private: out_view_type out_; output_pair pair_; }; -public: + public: /** * @brief Constructor taking no vectors, * that effectively makes all the vectors length zero. @@ -146,7 +132,8 @@ public: */ rmat_rectangular_gen_output(const out_src_view_type& src, const out_dst_view_type& dst) : data_(output_pair(src, dst)) - {} + { + } /** * @brief Constructor taking all three vectors. @@ -163,22 +150,22 @@ public: const out_src_view_type& src, const out_dst_view_type& dst) : data_(output_triple(out, src, dst)) - {} + { + } /** * @brief Whether this object was created with a constructor * taking more than zero arguments. */ - bool has_value() const { - return not std::holds_alternative(data_); - } + bool has_value() const { return not std::holds_alternative(data_); } /** * @brief Vector for the output single edgelist; the argument given * to the one-argument constructor, or the first argument of the * three-argument constructor; `std::nullopt` if not provided. */ - std::optional out_view() const { + std::optional out_view() const + { if (std::holds_alternative(data_)) { return std::get(data_); } else if (std::holds_alternative(data_)) { @@ -193,7 +180,8 @@ public: * given to the two-argument constructor, or the second argument * of the three-argument constructor; `std::nullopt` if not provided. */ - std::optional out_src_view() const { + std::optional out_src_view() const + { if (std::holds_alternative(data_)) { return std::get(data_).out_src_view(); } else if (std::holds_alternative(data_)) { @@ -209,7 +197,8 @@ public: * argument of the three-argument constructor; * `std::nullopt` if not provided. */ - std::optional out_dst_view() const { + std::optional out_dst_view() const + { if (std::holds_alternative(data_)) { return std::get(data_).out_dst_view(); } else if (std::holds_alternative(data_)) { @@ -223,22 +212,20 @@ public: * @brief Number of edges in the graph; zero if no output vector * was provided to the constructor. */ - IdxT number_of_edges() const { + IdxT number_of_edges() const + { if (std::holds_alternative(data_)) { return std::get(data_).extent(0); - } - else if (std::holds_alternative(data_)) { + } else if (std::holds_alternative(data_)) { return std::get(data_).number_of_edges(); - } - else if (std::holds_alternative(data_)) { + } else if (std::holds_alternative(data_)) { return std::get(data_).number_of_edges(); - } - else { + } else { return IdxT(0); } } -private: + private: // Defaults to std::nullopt. std::variant data_; }; @@ -285,30 +272,32 @@ void rmat_rectangular_gen(const raft::handle_t& handle, IdxT r_scale, IdxT c_scale) { - static_assert(std::is_integral_v, "rmat_rectangular_gen: " + static_assert(std::is_integral_v, + "rmat_rectangular_gen: " "Template parameter IdxT must be an integral type"); if (not output.has_value()) { - return; // nowhere to write output, so nothing to do + return; // nowhere to write output, so nothing to do } const IdxT expected_theta_len = IdxT(4) * (r_scale >= c_scale ? r_scale : c_scale); - RAFT_EXPECTS(theta.extent(0) == expected_theta_len, "rmat_rectangular_gen: " + RAFT_EXPECTS(theta.extent(0) == expected_theta_len, + "rmat_rectangular_gen: " "theta.extent(0) = %d != 2 * 2 * max(r_scale = %d, c_scale = %d) = %d", static_cast(theta.extent(0)), static_cast(r_scale), static_cast(c_scale), static_cast(expected_theta_len)); - auto out = output.out_view(); - auto out_src = output.out_src_view(); - auto out_dst = output.out_dst_view(); - const bool out_has_value = out.has_value(); + auto out = output.out_view(); + auto out_src = output.out_src_view(); + auto out_dst = output.out_dst_view(); + const bool out_has_value = out.has_value(); const bool out_src_has_value = out_src.has_value(); const bool out_dst_has_value = out_dst.has_value(); - IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr; - IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr; - IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; - const IdxT n_edges = output.number_of_edges(); + IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr; + IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr; + IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; + const IdxT n_edges = output.number_of_edges(); detail::rmat_rectangular_gen_caller(out_ptr, out_src_ptr, @@ -338,34 +327,26 @@ void rmat_rectangular_gen(const raft::handle_t& handle, IdxT r_scale, IdxT c_scale) { - static_assert(std::is_integral_v, "rmat_rectangular_gen: " + static_assert(std::is_integral_v, + "rmat_rectangular_gen: " "Template parameter IdxT must be an integral type"); if (not output.has_value()) { - return; // nowhere to write output, so nothing to do + return; // nowhere to write output, so nothing to do } - auto out = output.out_view(); - auto out_src = output.out_src_view(); - auto out_dst = output.out_dst_view(); - const bool out_has_value = out.has_value(); + auto out = output.out_view(); + auto out_src = output.out_src_view(); + auto out_dst = output.out_dst_view(); + const bool out_has_value = out.has_value(); const bool out_src_has_value = out_src.has_value(); const bool out_dst_has_value = out_dst.has_value(); - IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr; - IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr; - IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; - const IdxT n_edges = output.number_of_edges(); + IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr; + IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr; + IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; + const IdxT n_edges = output.number_of_edges(); - detail::rmat_rectangular_gen_caller(out_ptr, - out_src_ptr, - out_dst_ptr, - a, - b, - c, - r_scale, - c_scale, - n_edges, - handle.get_stream(), - r); + detail::rmat_rectangular_gen_caller( + out_ptr, out_src_ptr, out_dst_ptr, a, b, c, r_scale, c_scale, n_edges, handle.get_stream(), r); } /** diff --git a/cpp/test/random/rmat_rectangular_generator.cu b/cpp/test/random/rmat_rectangular_generator.cu index 4ca5579926..7f25e19c9b 100644 --- a/cpp/test/random/rmat_rectangular_generator.cu +++ b/cpp/test/random/rmat_rectangular_generator.cu @@ -301,21 +301,10 @@ class RmatGenMdspanTest : public ::testing::TestWithParam { if (params.theta_array) { raft::device_vector_view theta_view(theta.data(), theta.size()); - rmat_rectangular_gen(handle, - state, - theta_view, - output, - params.r_scale, - params.c_scale); + rmat_rectangular_gen(handle, state, theta_view, output, params.r_scale, params.c_scale); } else { - rmat_rectangular_gen(handle, - state, - output, - h_theta[0], - h_theta[1], - h_theta[2], - params.r_scale, - params.c_scale); + rmat_rectangular_gen( + handle, state, output, h_theta[0], h_theta[1], h_theta[2], params.r_scale, params.c_scale); } RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); }