Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

elementwise_min|max reduction op #3341

Merged
merged 2 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle,
ReduceOp reduce_op,
VertexValueOutputIterator vertex_value_output_first)
{
static_assert(ReduceOp::pure_function || reduce_op::has_compatible_raft_comms_op_v<ReduceOp> ||
static_assert(ReduceOp::pure_function && reduce_op::has_compatible_raft_comms_op_v<ReduceOp> &&
reduce_op::has_identity_element_v<ReduceOp>); // current restriction, to support
// general reduction, we may need to
// take a less efficient code path
Expand Down
48 changes: 26 additions & 22 deletions cpp/src/prims/property_op_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -123,60 +123,62 @@ struct atomic_add_thrust_tuple_impl<Iterator, TupleType, I, I> {
};

template <typename T>
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> atomic_min_impl(
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> elementwise_atomic_min_impl(
thrust::detail::any_assign& /* dereferencing thrust::discard_iterator results in this type */ lhs,
T const& rhs)
{
// no-op
}

template <typename T>
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> atomic_min_impl(T& lhs,
T const& rhs)
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> elementwise_atomic_min_impl(
T& lhs, T const& rhs)
{
atomicMin(&lhs, rhs);
}

template <typename Iterator, typename TupleType, size_t I, size_t N>
struct atomic_min_thrust_tuple_impl {
struct elementwise_atomic_min_thrust_tuple_impl {
__device__ constexpr void compute(Iterator iter, TupleType const& value) const
{
atomic_min_impl(thrust::raw_reference_cast(thrust::get<I>(*iter)), thrust::get<I>(value));
atomic_min_thrust_tuple_impl<Iterator, TupleType, I + 1, N>().compute(iter, value);
elementwise_atomic_min_impl(thrust::raw_reference_cast(thrust::get<I>(*iter)),
thrust::get<I>(value));
elementwise_atomic_min_thrust_tuple_impl<Iterator, TupleType, I + 1, N>().compute(iter, value);
}
};

template <typename Iterator, typename TupleType, size_t I>
struct atomic_min_thrust_tuple_impl<Iterator, TupleType, I, I> {
struct elementwise_atomic_min_thrust_tuple_impl<Iterator, TupleType, I, I> {
__device__ constexpr void compute(Iterator iter, TupleType const& value) const {}
};

template <typename T>
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> atomic_max_impl(
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> elementwise_atomic_max_impl(
thrust::detail::any_assign& /* dereferencing thrust::discard_iterator results in this type */ lhs,
T const& rhs)
{
// no-op
}

template <typename T>
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> atomic_max_impl(T& lhs,
T const& rhs)
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> elementwise_atomic_max_impl(
T& lhs, T const& rhs)
{
atomicMax(&lhs, rhs);
}

template <typename Iterator, typename TupleType, size_t I, size_t N>
struct atomic_max_thrust_tuple_impl {
struct elementwise_atomic_max_thrust_tuple_impl {
__device__ constexpr void compute(Iterator iter, TupleType const& value) const
{
atomic_max_impl(thrust::raw_reference_cast(thrust::get<I>(*iter)), thrust::get<I>(value));
atomic_max_thrust_tuple_impl<Iterator, TupleType, I + 1, N>().compute(iter, value);
elementwise_atomic_max_impl(thrust::raw_reference_cast(thrust::get<I>(*iter)),
thrust::get<I>(value));
elementwise_atomic_max_thrust_tuple_impl<Iterator, TupleType, I + 1, N>().compute(iter, value);
}
};

template <typename Iterator, typename TupleType, size_t I>
struct atomic_max_thrust_tuple_impl<Iterator, TupleType, I, I> {
struct elementwise_atomic_max_thrust_tuple_impl<Iterator, TupleType, I, I> {
__device__ constexpr void compute(Iterator iter, TupleType const& value) const {}
};

Expand Down Expand Up @@ -292,7 +294,7 @@ __device__

template <typename Iterator, typename T>
__device__ std::enable_if_t<thrust::detail::is_discard_iterator<Iterator>::value, void>
atomic_min_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_min_edge_op_result(Iterator iter, T const& value)
{
// no-op
}
Expand All @@ -302,7 +304,7 @@ __device__
std::enable_if_t<std::is_same<typename thrust::iterator_traits<Iterator>::value_type, T>::value &&
std::is_arithmetic<T>::value,
void>
atomic_min_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_min_edge_op_result(Iterator iter, T const& value)
{
atomicMin(&(thrust::raw_reference_cast(*iter)), value);
}
Expand All @@ -312,17 +314,18 @@ __device__
std::enable_if_t<is_thrust_tuple<typename thrust::iterator_traits<Iterator>::value_type>::value &&
is_thrust_tuple<T>::value,
void>
atomic_min_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_min_edge_op_result(Iterator iter, T const& value)
{
static_assert(thrust::tuple_size<typename thrust::iterator_traits<Iterator>::value_type>::value ==
thrust::tuple_size<T>::value);
size_t constexpr tuple_size = thrust::tuple_size<T>::value;
detail::atomic_min_thrust_tuple_impl<Iterator, T, size_t{0}, tuple_size>().compute(iter, value);
detail::elementwise_atomic_min_thrust_tuple_impl<Iterator, T, size_t{0}, tuple_size>().compute(
iter, value);
}

template <typename Iterator, typename T>
__device__ std::enable_if_t<thrust::detail::is_discard_iterator<Iterator>::value, void>
atomic_max_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_max_edge_op_result(Iterator iter, T const& value)
{
// no-op
}
Expand All @@ -332,7 +335,7 @@ __device__
std::enable_if_t<std::is_same<typename thrust::iterator_traits<Iterator>::value_type, T>::value &&
std::is_arithmetic<T>::value,
void>
atomic_max_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_max_edge_op_result(Iterator iter, T const& value)
{
atomicMax(&(thrust::raw_reference_cast(*iter)), value);
}
Expand All @@ -342,12 +345,13 @@ __device__
std::enable_if_t<is_thrust_tuple<typename thrust::iterator_traits<Iterator>::value_type>::value &&
is_thrust_tuple<T>::value,
void>
atomic_max_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_max_edge_op_result(Iterator iter, T const& value)
{
static_assert(thrust::tuple_size<typename thrust::iterator_traits<Iterator>::value_type>::value ==
thrust::tuple_size<T>::value);
size_t constexpr tuple_size = thrust::tuple_size<T>::value;
detail::atomic_max_thrust_tuple_impl<Iterator, T, size_t{0}, tuple_size>().compute(iter, value);
detail::elementwise_atomic_max_thrust_tuple_impl<Iterator, T, size_t{0}, tuple_size>().compute(
iter, value);
}

} // namespace cugraph
136 changes: 126 additions & 10 deletions cpp/src/prims/reduce_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,37 @@

#include <prims/property_op_utils.cuh>

#include <cugraph/utilities/thrust_tuple_utils.hpp>

#include <raft/core/comms.hpp>

#include <thrust/functional.h>

#include <utility>

namespace cugraph {
namespace reduce_op {

namespace detail {

template <typename T, std::size_t... Is>
__host__ __device__ std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<T>::value, T>
elementwise_thrust_min(T lhs, T rhs, std::index_sequence<Is...>)
{
return thrust::make_tuple(
(thrust::get<Is>(lhs) < thrust::get<Is>(rhs) ? thrust::get<Is>(lhs) : thrust::get<Is>(rhs))...);
}

template <typename T, std::size_t... Is>
__host__ __device__ std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<T>::value, T>
elementwise_thrust_max(T lhs, T rhs, std::index_sequence<Is...>)
{
return thrust::make_tuple(
(thrust::get<Is>(lhs) < thrust::get<Is>(rhs) ? thrust::get<Is>(rhs) : thrust::get<Is>(lhs))...);
}

} // namespace detail

// Guidance on writing a custom reduction operator.
// 1. It is required to add an "using value_type = type_of_the_reduced_values" statement.
// 2. A custom reduction operator MUST be side-effect free. We use thrust::reduce internally to
Expand Down Expand Up @@ -52,8 +76,8 @@ struct null {
using value_type = void;
};

// Binary reduction operator selecting any of the two input arguments, T should be arithmetic types
// or thrust tuple of arithmetic types.
// Binary reduction operator selecting any of the two input arguments, T should be an arithmetic
// type or a thrust tuple of arithmetic types.
template <typename T>
struct any {
using value_type = T;
Expand All @@ -62,10 +86,13 @@ struct any {
__host__ __device__ T operator()(T const& lhs, T const& rhs) const { return lhs; }
};

template <typename T, typename Enable = void>
struct minimum;

// Binary reduction operator selecting the minimum element of the two input arguments (using
// operator <), T should be arithmetic types or thrust tuple of arithmetic types.
// operator <), a compatible raft comms op exists if T is an arithmetic type.
template <typename T>
struct minimum {
struct minimum<T, std::enable_if_t<std::is_arithmetic_v<T>>> {
using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MIN;
Expand All @@ -77,10 +104,55 @@ struct minimum {
}
};

// Binary reduction operator selecting the minimum element of the two input arguments (using
// operator <), a compatible raft comms op does not exist when T is a thrust::tuple type.
template <typename T>
struct minimum<T, std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<T>::value>> {
using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
inline static T const identity_element = max_identity_element<T>();

__host__ __device__ T operator()(T const& lhs, T const& rhs) const
{
return lhs < rhs ? lhs : rhs;
}
};

// Binary reduction operator selecting the minimum element of the two input arguments elementwise
// (using operator < for each element), T should be an arithmetic type (this is identical to
// reduce_op::minimum if T is an arithmetic type) or a thrust tuple of arithmetic types.
template <typename T>
struct elementwise_minimum {
static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic<T>::value);

using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MIN;
inline static T const identity_element = max_identity_element<T>();

template <typename U = T>
__host__ __device__ std::enable_if_t<std::is_arithmetic_v<U>, T> operator()(T const& lhs,
T const& rhs) const
{
return lhs < rhs ? lhs : rhs;
}

template <typename U = T>
__host__ __device__ std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<U>::value, T>
operator()(T const& lhs, T const& rhs) const
{
return detail::elementwise_thrust_min(
lhs, rhs, std::make_index_sequence<thrust::tuple_size<T>::value>());
}
};

template <typename T, typename Enable = void>
struct maximum;

// Binary reduction operator selecting the maximum element of the two input arguments (using
// operator <), T should be arithmetic types or thrust tuple of arithmetic types.
// operator <), a compatible raft comms op exists if T is an arithmetic type.
template <typename T>
struct maximum {
struct maximum<T, std::enable_if_t<std::is_arithmetic_v<T>>> {
using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MAX;
Expand All @@ -92,10 +164,54 @@ struct maximum {
}
};

// Binary reduction operator summing the two input arguments, T should be arithmetic types or thrust
// tuple of arithmetic types.
// Binary reduction operator selecting the maximum element of the two input arguments (using
// operator <), a compatible raft comms op does not exist when T is a thrust::tuple type.
template <typename T>
struct maximum<T, std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<T>::value>> {
using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
inline static T const identity_element = min_identity_element<T>();

__host__ __device__ T operator()(T const& lhs, T const& rhs) const
{
return lhs < rhs ? rhs : lhs;
}
};

// Binary reduction operator selecting the maximum element of the two input arguments elementwise
// (using operator < for each element), T should be an arithmetic type (this is identical to
// reduce_op::maximum if T is an arithmetic type) or a thrust tuple of arithmetic types.
template <typename T>
struct elementwise_maximum {
static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic<T>::value);

using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MAX;
inline static T const identity_element = min_identity_element<T>();

template <typename U = T>
__host__ __device__ std::enable_if_t<std::is_arithmetic_v<U>, T> operator()(T const& lhs,
T const& rhs) const
{
return lhs < rhs ? rhs : lhs;
}

template <typename U = T>
__host__ __device__ std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<U>::value, T>
operator()(T const& lhs, T const& rhs) const
{
return detail::elementwise_thrust_max(
lhs, rhs, std::make_index_sequence<thrust::tuple_size<T>::value>());
}
};

// Binary reduction operator summing the two input arguments, T should be an arithmetic type or a
// thrust tuple of arithmetic types.
template <typename T>
struct plus {
static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic<T>::value);

using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::SUM;
Expand Down Expand Up @@ -146,9 +262,9 @@ __device__ std::enable_if_t<has_compatible_raft_comms_op_v<ReduceOp>, void> atom
if constexpr (ReduceOp::compatible_raft_comms_op == raft::comms::op_t::SUM) {
atomic_add_edge_op_result(iter, value);
} else if constexpr (ReduceOp::compatible_raft_comms_op == raft::comms::op_t::MIN) {
atomic_min_edge_op_result(iter, value);
elementwise_atomic_min_edge_op_result(iter, value);
} else {
atomic_max_edge_op_result(iter, value);
elementwise_atomic_max_edge_op_result(iter, value);
}
}

Expand Down
Loading