Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] mdspan-ify rmat_rectangular_gen #833

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions cpp/include/raft/random/detail/rmat_rectangular_generator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

#pragma once

#include "rmat_rectangular_generator_types.cuh"

#include <raft/core/handle.hpp>
#include <raft/random/rng_device.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/util/cuda_utils.cuh>
Expand Down Expand Up @@ -182,6 +185,111 @@ void rmat_rectangular_gen_caller(IdxT* out,
r.advance(n_edges, max_scale);
}

/**
* @brief Implementation of `raft::random::rmat_rectangular_gen_impl`.
*
* @tparam IdxT type of each node index
* @tparam ProbT data type used for probability distributions (either fp32 or fp64)
* @param[in] handle RAFT handle, containing the CUDA stream on which to schedule work
* @param[in] r underlying state of the random generator. Especially useful when
* one wants to call this API for multiple times in order to generate
* a larger graph. For that case, just create this object with the
* initial seed once and after every call continue to pass the same
* object for the successive calls.
* @param[out] output Encapsulation of one, two, or three output vectors.
* @param[in] theta distribution of each quadrant at each level of resolution.
* Since these are probabilities, each of the 2x2 matrices for
* each level of the RMAT must sum to one. [on device]
* [dim = max(r_scale, c_scale) x 2 x 2]. Of course, it is assumed
* that each of the group of 2 x 2 numbers all sum up to 1.
* @param[in] r_scale 2^r_scale represents the number of source nodes
* @param[in] c_scale 2^c_scale represents the number of destination nodes
*/
template <typename IdxT, typename ProbT>
void rmat_rectangular_gen_impl(const raft::handle_t& handle,
raft::random::RngState& r,
raft::device_vector_view<const ProbT, IdxT> theta,
raft::random::detail::rmat_rectangular_gen_output<IdxT> output,
IdxT r_scale,
IdxT c_scale)
{
static_assert(std::is_integral_v<IdxT>,
"rmat_rectangular_gen: "
"Template parameter IdxT must be an integral type");
if (output.empty()) {
return; // nothing to do; not an error
}

const IdxT expected_theta_len = IdxT(4) * (r_scale >= c_scale ? r_scale : c_scale);
RAFT_EXPECTS(theta.extent(0) == expected_theta_len,
"rmat_rectangular_gen: "
"theta.extent(0) = %zu != 2 * 2 * max(r_scale = %zu, c_scale = %zu) = %zu",
static_cast<std::size_t>(theta.extent(0)),
static_cast<std::size_t>(r_scale),
static_cast<std::size_t>(c_scale),
static_cast<std::size_t>(expected_theta_len));

auto out = output.out_view();
auto out_src = output.out_src_view();
auto out_dst = output.out_dst_view();
const bool out_has_value = out.has_value();
const bool out_src_has_value = out_src.has_value();
const bool out_dst_has_value = out_dst.has_value();
IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr;
IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr;
IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr;
const IdxT n_edges = output.number_of_edges();

rmat_rectangular_gen_caller(out_ptr,
out_src_ptr,
out_dst_ptr,
theta.data_handle(),
r_scale,
c_scale,
n_edges,
handle.get_stream(),
r);
}

/**
* @brief Overload of `rmat_rectangular_gen` that assumes the same
* a, b, c, d probability distributions across all the scales.
*
* `a`, `b, and `c` effectively replace the above overload's
* `theta` parameter.
*/
template <typename IdxT, typename ProbT>
void rmat_rectangular_gen_impl(const raft::handle_t& handle,
raft::random::RngState& r,
raft::random::detail::rmat_rectangular_gen_output<IdxT> output,
ProbT a,
ProbT b,
ProbT c,
IdxT r_scale,
IdxT c_scale)
{
static_assert(std::is_integral_v<IdxT>,
"rmat_rectangular_gen: "
"Template parameter IdxT must be an integral type");
if (output.empty()) {
return; // nothing to do; not an error
}

auto out = output.out_view();
auto out_src = output.out_src_view();
auto out_dst = output.out_dst_view();
const bool out_has_value = out.has_value();
const bool out_src_has_value = out_src.has_value();
const bool out_dst_has_value = out_dst.has_value();
IdxT* out_ptr = out_has_value ? (*out).data_handle() : nullptr;
IdxT* out_src_ptr = out_src_has_value ? (*out_src).data_handle() : nullptr;
IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr;
const IdxT n_edges = output.number_of_edges();

detail::rmat_rectangular_gen_caller(
out_ptr, out_src_ptr, out_dst_ptr, a, b, c, r_scale, c_scale, n_edges, handle.get_stream(), r);
}

} // end namespace detail
} // end namespace random
} // end namespace raft
259 changes: 259 additions & 0 deletions cpp/include/raft/random/detail/rmat_rectangular_generator_types.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/random/rng_device.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <optional>
#include <variant>

namespace raft {
namespace random {
namespace detail {

/**
* @brief Implementation detail for checking output vector parameter(s)
* of `raft::random::rmat_rectangular_gen`.
*
* `raft::random::rmat_rectangular_gen` lets users specify
* output vector(s) in three different ways.
*
* 1. One vector: `out`, an "array-of-structs" representation
* of the edge list.
*
* 2. Two vectors: `out_src` and `out_dst`, together forming
* a "struct of arrays" representation of the edge list.
*
* 3. Three vectors: `out`, `out_src`, and `out_dst`.
* `out` is as in (1),
* and `out_src` and `out_dst` are as in (2).
*
* This class prevents users from doing anything other than that,
* and makes it easier for the three cases to share a common implementation.
* It also prevents duplication of run-time vector length checking
* (`out` must have twice the number of elements as `out_src` and `out_dst`,
* and `out_src` and `out_dst` must have the same length).
*
* @tparam IdxT Type of each node index; must be integral.
*
* The following examples show how to create an output parameter.
*
* @code
* rmat_rectangular_gen_output<IdxT> output1(out);
* rmat_rectangular_gen_output<IdxT> output2(out_src, out_dst);
* rmat_rectangular_gen_output<IdxT> output3(out, out_src, out_dst);
* @endcode
*/
template <typename IdxT>
class rmat_rectangular_gen_output {
public:
using out_view_type =
raft::device_mdspan<IdxT, raft::extents<IdxT, raft::dynamic_extent, 2>, raft::row_major>;
using out_src_view_type = raft::device_vector_view<IdxT, IdxT>;
using out_dst_view_type = raft::device_vector_view<IdxT, IdxT>;

private:
class output_pair {
public:
output_pair(const out_src_view_type& src, const out_dst_view_type& dst) : src_(src), dst_(dst)
{
RAFT_EXPECTS(src.extent(0) == dst.extent(0),
"rmat_rectangular_gen: "
"out_src.extent(0) = %zu != out_dst.extent(0) = %zu",
static_cast<std::size_t>(src.extent(0)),
static_cast<std::size_t>(dst.extent(0)));
}

out_src_view_type out_src_view() const { return src_; }

out_dst_view_type out_dst_view() const { return dst_; }

IdxT number_of_edges() const { return src_.extent(0); }

bool empty() const { return src_.extent(0) == 0 && dst_.extent(0) == 0; }

private:
out_src_view_type src_;
out_dst_view_type dst_;
};

class output_triple {
public:
output_triple(const out_view_type& out,
const out_src_view_type& src,
const out_dst_view_type& dst)
: out_(out), pair_(src, dst)
{
RAFT_EXPECTS(out.extent(0) == IdxT(2) * dst.extent(0),
"rmat_rectangular_gen: "
"out.extent(0) = %zu != 2 * out_dst.extent(0) = %zu",
static_cast<std::size_t>(out.extent(0)),
static_cast<std::size_t>(IdxT(2) * dst.extent(0)));
}

out_view_type out_view() const { return out_; }

out_src_view_type out_src_view() const { return pair_.out_src_view(); }

out_dst_view_type out_dst_view() const { return pair_.out_dst_view(); }

IdxT number_of_edges() const { return pair_.number_of_edges(); }

bool empty() const { return out_.extent(0) == 0 && pair_.empty(); }

private:
out_view_type out_;
output_pair pair_;
};

public:
/**
* @brief You're not allowed to construct this with no vectors.
*/
rmat_rectangular_gen_output() = delete;

/**
* @brief Constructor taking a single vector, that packs the source
* node ids and destination node ids in array-of-structs fashion.
*
* @param[out] out Generated edgelist [on device]. In each row, the
* first element is the source node id, and the second element is
* the destination node id.
*/
rmat_rectangular_gen_output(const out_view_type& out) : data_(out) {}

/**
* @brief Constructor taking two vectors, that store the source node
* ids and the destination node ids separately, in
* struct-of-arrays fashion.
*
* @param[out] out_src Source node id's [on device] [len = n_edges].
*
* @param[out] out_dst Destination node id's [on device] [len = n_edges].
*/
rmat_rectangular_gen_output(const out_src_view_type& src, const out_dst_view_type& dst)
: data_(output_pair(src, dst))
{
}

/**
* @brief Constructor taking all three vectors.
*
* @param[out] out Generated edgelist [on device]. In each row, the
* first element is the source node id, and the second element is
* the destination node id.
*
* @param[out] out_src Source node id's [on device] [len = n_edges].
*
* @param[out] out_dst Destination node id's [on device] [len = n_edges].
*/
rmat_rectangular_gen_output(const out_view_type& out,
const out_src_view_type& src,
const out_dst_view_type& dst)
: data_(output_triple(out, src, dst))
{
}

/**
* @brief Whether the vector(s) are all length zero.
*/
bool empty() const
{
if (std::holds_alternative<out_view_type>(data_)) {
return std::get<out_view_type>(data_).extent(0) == 0;
} else if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).empty();
} else { // std::holds_alternative<output_triple>(data_)
return std::get<output_triple>(data_).empty();
}
}

/**
* @brief Vector for the output single edgelist; the argument given
* to the one-argument constructor, or the first argument of the
* three-argument constructor; `std::nullopt` if not provided.
*/
std::optional<out_view_type> out_view() const
{
if (std::holds_alternative<out_view_type>(data_)) {
return std::get<out_view_type>(data_);
} else if (std::holds_alternative<output_triple>(data_)) {
return std::get<output_triple>(data_).out_view();
} else { // if (std::holds_alternative<>(output_pair))
return std::nullopt;
}
}

/**
* @brief Vector for the output source edgelist; the first argument
* given to the two-argument constructor, or the second argument
* of the three-argument constructor; `std::nullopt` if not provided.
*/
std::optional<out_src_view_type> out_src_view() const
{
if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).out_src_view();
} else if (std::holds_alternative<output_triple>(data_)) {
return std::get<output_triple>(data_).out_src_view();
} else { // if (std::holds_alternative<out_view_type>(data_))
return std::nullopt;
}
}

/**
* @brief Vector for the output destination edgelist; the second
* argument given to the two-argument constructor, or the third
* argument of the three-argument constructor;
* `std::nullopt` if not provided.
*/
std::optional<out_dst_view_type> out_dst_view() const
{
if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).out_dst_view();
} else if (std::holds_alternative<output_triple>(data_)) {
return std::get<output_triple>(data_).out_dst_view();
} else { // if (std::holds_alternative<out_view_type>(data_))
return std::nullopt;
}
}

/**
* @brief Number of edges in the graph; zero if no output vector
* was provided to the constructor.
*/
IdxT number_of_edges() const
{
if (std::holds_alternative<out_view_type>(data_)) {
return std::get<out_view_type>(data_).extent(0);
} else if (std::holds_alternative<output_pair>(data_)) {
return std::get<output_pair>(data_).number_of_edges();
} else { // if (std::holds_alternative<output_triple>(data_))
return std::get<output_triple>(data_).number_of_edges();
}
}

private:
std::variant<out_view_type, output_pair, output_triple> data_;
};

} // end namespace detail
} // end namespace random
} // end namespace raft
Loading