Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mhoemmen committed Sep 23, 2022
1 parent 8d39780 commit f63415f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 96 deletions.
145 changes: 63 additions & 82 deletions cpp/include/raft/random/rmat_rectangular_generator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
#include "detail/rmat_rectangular_generator.cuh"

#include <optional>
#include <variant>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <variant>

namespace raft::random {

Expand All @@ -48,77 +48,63 @@ namespace raft::random {
*/
template <typename IdxT>
class rmat_rectangular_gen_output {
public:
public:
using out_view_type =
raft::device_mdspan<IdxT, raft::extents<IdxT, raft::dynamic_extent, 2>, raft::row_major>;
using out_src_view_type = raft::device_vector_view<IdxT, IdxT>;
using out_dst_view_type = raft::device_vector_view<IdxT, IdxT>;

private:
private:
class output_pair {
public:
output_pair(const out_src_view_type& src,
const out_dst_view_type& dst)
: src_(src), dst_(dst)
public:
output_pair(const out_src_view_type& src, const out_dst_view_type& dst) : src_(src), dst_(dst)
{
RAFT_EXPECTS(src.extent(0) == dst.extent(0), "rmat_rectangular_gen: "
RAFT_EXPECTS(src.extent(0) == dst.extent(0),
"rmat_rectangular_gen: "
"out_src.extent(0) = %d != out_dst.extent(0) = %d",
static_cast<int>(src.extent(0)),
static_cast<int>(dst.extent(0)));
}

out_src_view_type out_src_view() const {
return src_;
}
out_src_view_type out_src_view() const { return src_; }

out_dst_view_type out_dst_view() const {
return dst_;
}
out_dst_view_type out_dst_view() const { return dst_; }

IdxT number_of_edges() const {
return src_.extent(0);
}
IdxT number_of_edges() const { return src_.extent(0); }

private:
private:
out_src_view_type src_;
out_dst_view_type dst_;
};

class output_triple {
public:
public:
output_triple(const out_view_type& out,
const out_src_view_type& src,
const out_dst_view_type& dst)
: out_(out), pair_(src, dst)
{
RAFT_EXPECTS(out.extent(0) == IdxT(2) * dst.extent(0), "rmat_rectangular_gen: "
RAFT_EXPECTS(out.extent(0) == IdxT(2) * dst.extent(0),
"rmat_rectangular_gen: "
"out.extent(0) = %d != 2 * out_dst.extent(0) = %d",
static_cast<int>(out.extent(0)),
static_cast<int>(IdxT(2) * dst.extent(0)));
}

out_view_type out_view() const {
return out_;
}
out_view_type out_view() const { return out_; }

out_src_view_type out_src_view() const {
return pair_.out_src_view();
}
out_src_view_type out_src_view() const { return pair_.out_src_view(); }

out_dst_view_type out_dst_view() const {
return pair_.out_dst_view();
}
out_dst_view_type out_dst_view() const { return pair_.out_dst_view(); }

IdxT number_of_edges() const {
return pair_.number_of_edges();
}
IdxT number_of_edges() const { return pair_.number_of_edges(); }

private:
private:
out_view_type out_;
output_pair pair_;
};

public:
public:
/**
* @brief Constructor taking no vectors,
* that effectively makes all the vectors length zero.
Expand Down Expand Up @@ -146,7 +132,8 @@ public:
*/
rmat_rectangular_gen_output(const out_src_view_type& src, const out_dst_view_type& dst)
: data_(output_pair(src, dst))
{}
{
}

/**
* @brief Constructor taking all three vectors.
Expand All @@ -163,22 +150,22 @@ public:
const out_src_view_type& src,
const out_dst_view_type& dst)
: data_(output_triple(out, src, dst))
{}
{
}

/**
* @brief Whether this object was created with a constructor
* taking more than zero arguments.
*/
bool has_value() const {
return not std::holds_alternative<std::nullopt_t>(data_);
}
bool has_value() const { return not std::holds_alternative<std::nullopt_t>(data_); }

/**
* @brief Vector for the output single edgelist; the argument given
* to the one-argument constructor, or the first argument of the
* three-argument constructor; `std::nullopt` if not provided.
*/
std::optional<out_view_type> out_view() const {
std::optional<out_view_type> out_view() const
{
if (std::holds_alternative<out_view_type>(data_)) {
return std::get<out_view_type>(data_);
} else if (std::holds_alternative<output_triple>(data_)) {
Expand All @@ -193,7 +180,8 @@ public:
* given to the two-argument constructor, or the second argument
* of the three-argument constructor; `std::nullopt` if not provided.
*/
std::optional<out_src_view_type> out_src_view() const {
std::optional<out_src_view_type> out_src_view() const
{
if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).out_src_view();
} else if (std::holds_alternative<output_triple>(data_)) {
Expand All @@ -209,7 +197,8 @@ public:
* argument of the three-argument constructor;
* `std::nullopt` if not provided.
*/
std::optional<out_dst_view_type> out_dst_view() const {
std::optional<out_dst_view_type> out_dst_view() const
{
if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).out_dst_view();
} else if (std::holds_alternative<output_triple>(data_)) {
Expand All @@ -223,22 +212,20 @@ public:
* @brief Number of edges in the graph; zero if no output vector
* was provided to the constructor.
*/
IdxT number_of_edges() const {
IdxT number_of_edges() const
{
if (std::holds_alternative<out_view_type>(data_)) {
return std::get<out_view_type>(data_).extent(0);
}
else if (std::holds_alternative<output_pair>(data_)) {
} else if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).number_of_edges();
}
else if (std::holds_alternative<output_triple>(data_)) {
} else if (std::holds_alternative<output_triple>(data_)) {
return std::get<output_triple>(data_).number_of_edges();
}
else {
} else {
return IdxT(0);
}
}

private:
private:
// Defaults to std::nullopt.
std::variant<std::nullopt_t, out_view_type, output_pair, output_triple> data_;
};
Expand Down Expand Up @@ -285,30 +272,32 @@ void rmat_rectangular_gen(const raft::handle_t& handle,
IdxT r_scale,
IdxT c_scale)
{
static_assert(std::is_integral_v<IdxT>, "rmat_rectangular_gen: "
static_assert(std::is_integral_v<IdxT>,
"rmat_rectangular_gen: "
"Template parameter IdxT must be an integral type");
if (not output.has_value()) {
return; // nowhere to write output, so nothing to do
return; // nowhere to write output, so nothing to do
}

const IdxT expected_theta_len = IdxT(4) * (r_scale >= c_scale ? r_scale : c_scale);
RAFT_EXPECTS(theta.extent(0) == expected_theta_len, "rmat_rectangular_gen: "
RAFT_EXPECTS(theta.extent(0) == expected_theta_len,
"rmat_rectangular_gen: "
"theta.extent(0) = %d != 2 * 2 * max(r_scale = %d, c_scale = %d) = %d",
static_cast<int>(theta.extent(0)),
static_cast<int>(r_scale),
static_cast<int>(c_scale),
static_cast<int>(expected_theta_len));

auto out = output.out_view();
auto out_src = output.out_src_view();
auto out_dst = output.out_dst_view();
const bool out_has_value = out.has_value();
auto out = output.out_view();
auto out_src = output.out_src_view();
auto out_dst = output.out_dst_view();
const bool out_has_value = out.has_value();
const bool out_src_has_value = out_src.has_value();
const bool out_dst_has_value = out_dst.has_value();
IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr;
IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr;
IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr;
const IdxT n_edges = output.number_of_edges();
IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr;
IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr;
IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr;
const IdxT n_edges = output.number_of_edges();

detail::rmat_rectangular_gen_caller(out_ptr,
out_src_ptr,
Expand Down Expand Up @@ -338,34 +327,26 @@ void rmat_rectangular_gen(const raft::handle_t& handle,
IdxT r_scale,
IdxT c_scale)
{
static_assert(std::is_integral_v<IdxT>, "rmat_rectangular_gen: "
static_assert(std::is_integral_v<IdxT>,
"rmat_rectangular_gen: "
"Template parameter IdxT must be an integral type");
if (not output.has_value()) {
return; // nowhere to write output, so nothing to do
return; // nowhere to write output, so nothing to do
}

auto out = output.out_view();
auto out_src = output.out_src_view();
auto out_dst = output.out_dst_view();
const bool out_has_value = out.has_value();
auto out = output.out_view();
auto out_src = output.out_src_view();
auto out_dst = output.out_dst_view();
const bool out_has_value = out.has_value();
const bool out_src_has_value = out_src.has_value();
const bool out_dst_has_value = out_dst.has_value();
IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr;
IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr;
IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr;
const IdxT n_edges = output.number_of_edges();
IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr;
IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr;
IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr;
const IdxT n_edges = output.number_of_edges();

detail::rmat_rectangular_gen_caller(out_ptr,
out_src_ptr,
out_dst_ptr,
a,
b,
c,
r_scale,
c_scale,
n_edges,
handle.get_stream(),
r);
detail::rmat_rectangular_gen_caller(
out_ptr, out_src_ptr, out_dst_ptr, a, b, c, r_scale, c_scale, n_edges, handle.get_stream(), r);
}

/**
Expand Down
17 changes: 3 additions & 14 deletions cpp/test/random/rmat_rectangular_generator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -301,21 +301,10 @@ class RmatGenMdspanTest : public ::testing::TestWithParam<RmatInputs> {

if (params.theta_array) {
raft::device_vector_view<const float, index_type> theta_view(theta.data(), theta.size());
rmat_rectangular_gen(handle,
state,
theta_view,
output,
params.r_scale,
params.c_scale);
rmat_rectangular_gen(handle, state, theta_view, output, params.r_scale, params.c_scale);
} else {
rmat_rectangular_gen(handle,
state,
output,
h_theta[0],
h_theta[1],
h_theta[2],
params.r_scale,
params.c_scale);
rmat_rectangular_gen(
handle, state, output, h_theta[0], h_theta[1], h_theta[2], params.r_scale, params.c_scale);
}
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
}
Expand Down

0 comments on commit f63415f

Please sign in to comment.