Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] mdspan-ify rmat_rectangular_gen #833

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
314 changes: 314 additions & 0 deletions cpp/include/raft/random/detail/rmat_rectangular_generator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/random/rng_device.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <optional>
#include <variant>

namespace raft {
namespace random {
namespace detail {
Expand Down Expand Up @@ -182,6 +187,315 @@ void rmat_rectangular_gen_caller(IdxT* out,
r.advance(n_edges, max_scale);
}

/**
* @brief Implementation detail for checking output vector parameter(s)
* of `raft::random::rmat_rectangular_gen`.
*
* @tparam IdxT Type of each node index; must be integral.
*
* Users can provide either `out` by itself, (`out_src` and `out_dst`)
* together, or all three (`out`, `out_src`, and `out_dst`).
* This class prevents users from doing anything other than that.
* It also checks compatibility of dimensions at run time.
*
* The following examples show how to create an output parameter.
*
* @code
* rmat_rectangular_gen_output<size_t> output1(out);
* rmat_rectangular_gen_output<size_t> output2(out_src, out_dst);
* rmat_rectangular_gen_output<size_t> output3(out, out_src, out_dst);
* @endcode
*
*/
template <typename IdxT>
class rmat_rectangular_gen_output {
public:
using out_view_type =
raft::device_mdspan<IdxT, raft::extents<IdxT, raft::dynamic_extent, 2>, raft::row_major>;
using out_src_view_type = raft::device_vector_view<IdxT, IdxT>;
using out_dst_view_type = raft::device_vector_view<IdxT, IdxT>;

private:
class output_pair {
public:
output_pair(const out_src_view_type& src, const out_dst_view_type& dst) : src_(src), dst_(dst)
{
RAFT_EXPECTS(src.extent(0) == dst.extent(0),
"rmat_rectangular_gen: "
"out_src.extent(0) = %zu != out_dst.extent(0) = %zu",
static_cast<std::size_t>(src.extent(0)),
static_cast<std::size_t>(dst.extent(0)));
}

out_src_view_type out_src_view() const { return src_; }

out_dst_view_type out_dst_view() const { return dst_; }

IdxT number_of_edges() const { return src_.extent(0); }

private:
out_src_view_type src_;
out_dst_view_type dst_;
};

class output_triple {
public:
output_triple(const out_view_type& out,
const out_src_view_type& src,
const out_dst_view_type& dst)
: out_(out), pair_(src, dst)
{
RAFT_EXPECTS(out.extent(0) == IdxT(2) * dst.extent(0),
"rmat_rectangular_gen: "
"out.extent(0) = %zu != 2 * out_dst.extent(0) = %zu",
static_cast<std::size_t>(out.extent(0)),
static_cast<std::size_t>(IdxT(2) * dst.extent(0)));
}

out_view_type out_view() const { return out_; }

out_src_view_type out_src_view() const { return pair_.out_src_view(); }

out_dst_view_type out_dst_view() const { return pair_.out_dst_view(); }

IdxT number_of_edges() const { return pair_.number_of_edges(); }

private:
out_view_type out_;
output_pair pair_;
};

public:
/**
* @brief Constructor taking no vectors,
* that effectively makes all the vectors length zero.
*/
rmat_rectangular_gen_output() = default;

/**
* @brief Constructor taking a single vector, that packs the source
* node ids and destination node ids in array-of-structs fashion.
*
* @param[out] out Generated edgelist [on device]. In each row, the
* first element is the source node id, and the second element is
* the destination node id.
*/
rmat_rectangular_gen_output(const out_view_type& out) : data_(out) {}

/**
* @brief Constructor taking two vectors, that store the source node
* ids and the destination node ids separately, in
* struct-of-arrays fashion.
*
* @param[out] out_src Source node id's [on device] [len = n_edges].
*
* @param[out] out_dst Destination node id's [on device] [len = n_edges].
*/
rmat_rectangular_gen_output(const out_src_view_type& src, const out_dst_view_type& dst)
: data_(output_pair(src, dst))
{
}

/**
* @brief Constructor taking all three vectors.
*
* @param[out] out Generated edgelist [on device]. In each row, the
* first element is the source node id, and the second element is
* the destination node id.
*
* @param[out] out_src Source node id's [on device] [len = n_edges].
*
* @param[out] out_dst Destination node id's [on device] [len = n_edges].
*/
rmat_rectangular_gen_output(const out_view_type& out,
const out_src_view_type& src,
const out_dst_view_type& dst)
: data_(output_triple(out, src, dst))
{
}

/**
* @brief Whether this object was created with a constructor
* taking more than zero arguments.
*/
bool has_value() const { return not std::holds_alternative<std::nullopt_t>(data_); }

/**
* @brief Vector for the output single edgelist; the argument given
* to the one-argument constructor, or the first argument of the
* three-argument constructor; `std::nullopt` if not provided.
*/
std::optional<out_view_type> out_view() const
{
if (std::holds_alternative<out_view_type>(data_)) {
return std::get<out_view_type>(data_);
} else if (std::holds_alternative<output_triple>(data_)) {
return std::get<output_triple>(data_).out_view();
} else {
return std::nullopt;
}
}

/**
* @brief Vector for the output source edgelist; the first argument
* given to the two-argument constructor, or the second argument
* of the three-argument constructor; `std::nullopt` if not provided.
*/
std::optional<out_src_view_type> out_src_view() const
{
if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).out_src_view();
} else if (std::holds_alternative<output_triple>(data_)) {
return std::get<output_triple>(data_).out_src_view();
} else {
return std::nullopt;
}
}

/**
* @brief Vector for the output destination edgelist; the second
* argument given to the two-argument constructor, or the third
* argument of the three-argument constructor;
* `std::nullopt` if not provided.
*/
std::optional<out_dst_view_type> out_dst_view() const
{
if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).out_dst_view();
} else if (std::holds_alternative<output_triple>(data_)) {
return std::get<output_triple>(data_).out_dst_view();
} else {
return std::nullopt;
}
}

/**
* @brief Number of edges in the graph; zero if no output vector
* was provided to the constructor.
*/
IdxT number_of_edges() const
{
if (std::holds_alternative<out_view_type>(data_)) {
return std::get<out_view_type>(data_).extent(0);
} else if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).number_of_edges();
} else if (std::holds_alternative<output_triple>(data_)) {
return std::get<output_triple>(data_).number_of_edges();
} else {
return IdxT(0);
}
}

private:
// Defaults to std::nullopt.
std::variant<std::nullopt_t, out_view_type, output_pair, output_triple> data_;
};

/**
 * @brief Implementation of `raft::random::rmat_rectangular_gen`
 *   for a per-level `theta` distribution.
 *
 * Returns immediately, without error, when `output` holds no view.
 * Otherwise checks the length of `theta` and forwards raw pointers
 * to `rmat_rectangular_gen_caller`.
 *
 * @tparam IdxT  type of each node index; must be integral
 * @tparam ProbT data type used for probability distributions (either fp32 or fp64)
 * @param[in]  handle  RAFT handle, containing the CUDA stream on which to schedule work
 * @param[in]  r       underlying state of the random generator. Especially useful when
 *                     one wants to call this API for multiple times in order to generate
 *                     a larger graph. For that case, just create this object with the
 *                     initial seed once and after every call continue to pass the same
 *                     object for the successive calls.
 * @param[in]  theta   distribution of each quadrant at each level of resolution.
 *                     Since these are probabilities, each of the 2x2 matrices for
 *                     each level of the RMAT must sum to one. [on device]
 *                     [dim = max(r_scale, c_scale) x 2 x 2]. The length is checked;
 *                     the sum-to-one property is assumed, not verified here.
 * @param[out] output  Encapsulation of one, two, or three output vectors.
 * @param[in]  r_scale 2^r_scale represents the number of source nodes
 * @param[in]  c_scale 2^c_scale represents the number of destination nodes
 */
template <typename IdxT, typename ProbT>
void rmat_rectangular_gen_impl(const raft::handle_t& handle,
                               raft::random::RngState& r,
                               raft::device_vector_view<const ProbT, IdxT> theta,
                               raft::random::detail::rmat_rectangular_gen_output<IdxT> output,
                               IdxT r_scale,
                               IdxT c_scale)
{
  static_assert(std::is_integral_v<IdxT>,
                "rmat_rectangular_gen: "
                "Template parameter IdxT must be an integral type");

  // No output view means there is nowhere to write; this is not an error.
  const bool nothing_to_write = not output.has_value();
  if (nothing_to_write) { return; }

  // theta carries one 2 x 2 block (four entries) per level, and the number
  // of levels is the larger of the two scales.
  const IdxT larger_scale       = c_scale > r_scale ? c_scale : r_scale;
  const IdxT expected_theta_len = IdxT(4) * larger_scale;
  RAFT_EXPECTS(theta.extent(0) == expected_theta_len,
               "rmat_rectangular_gen: "
               "theta.extent(0) = %zu != 2 * 2 * max(r_scale = %zu, c_scale = %zu) = %zu",
               static_cast<std::size_t>(theta.extent(0)),
               static_cast<std::size_t>(r_scale),
               static_cast<std::size_t>(c_scale),
               static_cast<std::size_t>(expected_theta_len));

  // Unpack each optional view into a raw pointer; a view that was not
  // provided is passed on as nullptr.
  const auto packed_view = output.out_view();
  const auto src_view    = output.out_src_view();
  const auto dst_view    = output.out_dst_view();

  IdxT* out_ptr = nullptr;
  if (packed_view) { out_ptr = packed_view->data_handle(); }

  IdxT* out_src_ptr = nullptr;
  if (src_view) { out_src_ptr = src_view->data_handle(); }

  IdxT* out_dst_ptr = nullptr;
  if (dst_view) { out_dst_ptr = dst_view->data_handle(); }

  const IdxT edge_count = output.number_of_edges();

  rmat_rectangular_gen_caller(out_ptr,
                              out_src_ptr,
                              out_dst_ptr,
                              theta.data_handle(),
                              r_scale,
                              c_scale,
                              edge_count,
                              handle.get_stream(),
                              r);
}

/**
 * @brief Overload of `rmat_rectangular_gen` that assumes the same
 *   a, b, c, d probability distributions across all the scales.
 *
 * `a`, `b`, and `c` effectively replace the `theta` parameter of the
 * overload that takes a per-level distribution.
 *
 * Returns immediately, without error, when `output` holds no view.
 * Otherwise forwards raw pointers, together with `a`, `b`, and `c`,
 * to `rmat_rectangular_gen_caller`.
 *
 * @tparam IdxT  type of each node index; must be integral
 * @tparam ProbT data type of `a`, `b`, and `c`
 * @param[in]  handle  RAFT handle; its stream is passed on to the caller routine
 * @param[in]  r       state of the random generator, passed on by reference
 * @param[out] output  Encapsulation of one, two, or three output vectors.
 * @param[in]  a       first of the three probability values shared by all scales
 * @param[in]  b       second of the three probability values shared by all scales
 * @param[in]  c       third of the three probability values shared by all scales
 * @param[in]  r_scale row scale, passed on unchanged
 * @param[in]  c_scale column scale, passed on unchanged
 */
template <typename IdxT, typename ProbT>
void rmat_rectangular_gen_impl(const raft::handle_t& handle,
                               raft::random::RngState& r,
                               raft::random::detail::rmat_rectangular_gen_output<IdxT> output,
                               ProbT a,
                               ProbT b,
                               ProbT c,
                               IdxT r_scale,
                               IdxT c_scale)
{
  static_assert(std::is_integral_v<IdxT>,
                "rmat_rectangular_gen: "
                "Template parameter IdxT must be an integral type");

  // Nowhere to write output, so nothing to do.
  const bool nothing_to_write = not output.has_value();
  if (nothing_to_write) { return; }

  // Unpack each optional view into a raw pointer; a view that was not
  // provided is passed on as nullptr.
  const auto packed_view = output.out_view();
  const auto src_view    = output.out_src_view();
  const auto dst_view    = output.out_dst_view();

  IdxT* out_ptr = nullptr;
  if (packed_view) { out_ptr = packed_view->data_handle(); }

  IdxT* out_src_ptr = nullptr;
  if (src_view) { out_src_ptr = src_view->data_handle(); }

  IdxT* out_dst_ptr = nullptr;
  if (dst_view) { out_dst_ptr = dst_view->data_handle(); }

  const IdxT edge_count = output.number_of_edges();

  detail::rmat_rectangular_gen_caller(out_ptr,
                                      out_src_ptr,
                                      out_dst_ptr,
                                      a,
                                      b,
                                      c,
                                      r_scale,
                                      c_scale,
                                      edge_count,
                                      handle.get_stream(),
                                      r);
}

} // end namespace detail
} // end namespace random
} // end namespace raft
Loading