Skip to content

Commit

Permalink
Refactor binaryop/compiled/util.cpp (#10756)
Browse files Browse the repository at this point in the history
This PR reduces the complexity of compile-time dispatches to resolve long compile times and massive memory usage in `binaryop/compiled/util.cpp`.

The file `binaryop/compiled/util.cpp` exposes two functions: `is_supported_operation(out, lhs, rhs, op)` and `get_common_type(out, lhs, rhs)`. I refactored both of them, since they were both expensive to compile.

In `is_supported_operation`, I replaced a quadruple dispatch (!!!!) on (LHS type, RHS type, binary operation, output type) with a triple dispatch (LHS type, RHS type, BinaryOp) and some runtime single-dispatches to handle the output type.

In `get_common_type`, I replaced a triple type dispatch on (output type, LHS type, RHS type) with a few double type dispatches. I used the definition of `std::common_type` to simplify `std::common_type_t<A, B, C>` into `std::common_type_t<std::common_type_t<A, B>, C>`, which means we can double-dispatch twice and use runtime `data_type` values in between.

**Impact:** Peak memory usage (max resident set size) when compiling this file drops from 14.6 GB (280acdf) to 2.4 GB (ee2c26a), and the time to compile drops from 2:52.48 minutes (280acdf) to 57.91 seconds (ee2c26a).

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - Karthikeyan (https://github.com/karthikeyann)

URL: #10756
  • Loading branch information
bdice authored May 11, 2022
1 parent 2b204d0 commit 0cc29a0
Showing 1 changed file with 98 additions and 73 deletions.
171 changes: 98 additions & 73 deletions cpp/src/binaryop/compiled/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,49 +26,62 @@
namespace cudf::binops::compiled {

namespace {
/**
* @brief Functor that returns optional common type of 2 or 3 given types.
*
*/

struct common_type_functor {
template <typename TypeLhs, typename TypeRhs>
struct nested_common_type_functor {
template <typename TypeOut>
std::optional<data_type> operator()()
{
// If common_type exists
if constexpr (cudf::has_common_type_v<TypeOut, TypeLhs, TypeRhs>) {
using TypeCommon = typename std::common_type<TypeOut, TypeLhs, TypeRhs>::type;
return data_type{type_to_id<TypeCommon>()};
} else if constexpr (cudf::has_common_type_v<TypeLhs, TypeRhs>) {
using TypeCommon = typename std::common_type<TypeLhs, TypeRhs>::type;
// Eg. d=t-t
return data_type{type_to_id<TypeCommon>()};
}

// A compiler bug may cause a compilation error when using empty initializer list to construct
// an std::optional object containing no `data_type` value. Therefore, we should explicitly
// return `std::nullopt` instead.
return std::nullopt;
std::optional<data_type> operator()() const
{
if constexpr (cudf::has_common_type_v<TypeLhs, TypeRhs>) {
using TypeCommon = std::common_type_t<TypeLhs, TypeRhs>;
return data_type{type_to_id<TypeCommon>()};
}
};
template <typename TypeLhs, typename TypeRhs>
std::optional<data_type> operator()(data_type out)

// A compiler bug may cause a compilation error when using empty
// initializer list to construct an std::optional object containing no
// `data_type` value. Therefore, we explicitly return `std::nullopt`
// instead.
return std::nullopt;
}
};

/**
 * @brief Functor reporting whether `mutable_column_device_view` provides an
 * element accessor for the type it is invoked with.
 *
 * Takes the type as a template argument and no runtime arguments.
 */
struct has_mutable_element_accessor_functor {
  template <typename ElementType>
  bool operator()() const
  {
    bool const accessor_available =
      mutable_column_device_view::has_element_accessor<ElementType>();
    return accessor_available;
  }
};

/**
 * @brief Runtime query: does `mutable_column_device_view` have an element
 * accessor for the type identified by `t`?
 *
 * @param t Runtime data type to dispatch on.
 * @return Result of `has_mutable_element_accessor_functor` for the dispatched type.
 */
bool has_mutable_element_accessor(data_type t)
{
  auto const accessor_check = has_mutable_element_accessor_functor{};
  return type_dispatcher(t, accessor_check);
}

template <typename InputType>
struct is_constructible_functor {
template <typename TargetType>
bool operator()() const
{
return type_dispatcher(out, nested_common_type_functor<TypeLhs, TypeRhs>{});
return std::is_constructible_v<TargetType, InputType>;
}
};

/**
 * @brief Runtime query on the target type for constructibility from `InputType`.
 *
 * Dispatches on `target_type` and forwards to
 * `is_constructible_functor<InputType>`, so only the input type has to be
 * known at compile time.
 *
 * @tparam InputType Compile-time source type.
 * @param target_type Runtime data type of the type to be constructed.
 * @return Result of `is_constructible_functor<InputType>` for the dispatched target type.
 */
template <typename InputType>
bool is_constructible(data_type target_type)
{
  using constructibility_check = is_constructible_functor<InputType>;
  return type_dispatcher(target_type, constructibility_check{});
}

/**
* @brief Functor that returns true if BinaryOperator supports the given input and output types.
*
* @tparam BinaryOperator binary operator functor
*/
template <typename BinaryOperator>
struct is_binary_operation_supported {
// For types where Out type is fixed. (eg. comparison types)
// For types where Out type is fixed. (e.g. comparison types)
template <typename TypeLhs, typename TypeRhs>
inline constexpr bool operator()()
inline constexpr bool operator()() const
{
if constexpr (column_device_view::has_element_accessor<TypeLhs>() and
column_device_view::has_element_accessor<TypeRhs>()) {
Expand All @@ -83,24 +96,22 @@ struct is_binary_operation_supported {
}
}

template <typename TypeOut, typename TypeLhs, typename TypeRhs>
inline constexpr bool operator()()
template <typename TypeLhs, typename TypeRhs>
inline constexpr bool operator()(data_type out_type) const
{
if constexpr (column_device_view::has_element_accessor<TypeLhs>() and
column_device_view::has_element_accessor<TypeRhs>() and
(mutable_column_device_view::has_element_accessor<TypeOut>() or
is_fixed_point<TypeOut>())) {
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (std::is_invocable_v<BinaryOperator, common_t, common_t>) {
using ReturnType = std::invoke_result_t<BinaryOperator, common_t, common_t>;
return std::is_constructible_v<TypeOut, ReturnType> or
(is_fixed_point<ReturnType>() and is_fixed_point<TypeOut>());
}
} else {
if constexpr (std::is_invocable_v<BinaryOperator, TypeLhs, TypeRhs>) {
column_device_view::has_element_accessor<TypeRhs>()) {
if (has_mutable_element_accessor(out_type) or is_fixed_point(out_type)) {
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (std::is_invocable_v<BinaryOperator, common_t, common_t>) {
using ReturnType = std::invoke_result_t<BinaryOperator, common_t, common_t>;
return is_constructible<ReturnType>(out_type) or
(is_fixed_point<ReturnType>() and is_fixed_point(out_type));
}
} else if constexpr (std::is_invocable_v<BinaryOperator, TypeLhs, TypeRhs>) {
using ReturnType = std::invoke_result_t<BinaryOperator, TypeLhs, TypeRhs>;
return std::is_constructible_v<TypeOut, ReturnType>;
return is_constructible<ReturnType>(out_type);
}
}
}
Expand All @@ -111,37 +122,36 @@ struct is_binary_operation_supported {
struct is_supported_operation_functor {
template <typename TypeLhs, typename TypeRhs>
struct nested_support_functor {
template <typename BinaryOperator, typename TypeOut>
inline constexpr bool call()
template <typename BinaryOperator>
inline constexpr bool call(data_type out_type) const
{
return is_binary_operation_supported<BinaryOperator>{}
.template operator()<TypeOut, TypeLhs, TypeRhs>();
return is_binary_operation_supported<BinaryOperator>{}.template operator()<TypeLhs, TypeRhs>(
out_type);
}
template <typename TypeOut>
inline constexpr bool operator()(binary_operator op)
inline constexpr bool operator()(binary_operator op, data_type out_type) const
{
switch (op) {
// clang-format off
case binary_operator::ADD: return call<ops::Add, TypeOut>();
case binary_operator::SUB: return call<ops::Sub, TypeOut>();
case binary_operator::MUL: return call<ops::Mul, TypeOut>();
case binary_operator::DIV: return call<ops::Div, TypeOut>();
case binary_operator::TRUE_DIV: return call<ops::TrueDiv, TypeOut>();
case binary_operator::FLOOR_DIV: return call<ops::FloorDiv, TypeOut>();
case binary_operator::MOD: return call<ops::Mod, TypeOut>();
case binary_operator::PYMOD: return call<ops::PyMod, TypeOut>();
case binary_operator::POW: return call<ops::Pow, TypeOut>();
case binary_operator::BITWISE_AND: return call<ops::BitwiseAnd, TypeOut>();
case binary_operator::BITWISE_OR: return call<ops::BitwiseOr, TypeOut>();
case binary_operator::BITWISE_XOR: return call<ops::BitwiseXor, TypeOut>();
case binary_operator::SHIFT_LEFT: return call<ops::ShiftLeft, TypeOut>();
case binary_operator::SHIFT_RIGHT: return call<ops::ShiftRight, TypeOut>();
case binary_operator::SHIFT_RIGHT_UNSIGNED: return call<ops::ShiftRightUnsigned, TypeOut>();
case binary_operator::LOG_BASE: return call<ops::LogBase, TypeOut>();
case binary_operator::ATAN2: return call<ops::ATan2, TypeOut>();
case binary_operator::PMOD: return call<ops::PMod, TypeOut>();
case binary_operator::NULL_MAX: return call<ops::NullMax, TypeOut>();
case binary_operator::NULL_MIN: return call<ops::NullMin, TypeOut>();
case binary_operator::ADD: return call<ops::Add>(out_type);
case binary_operator::SUB: return call<ops::Sub>(out_type);
case binary_operator::MUL: return call<ops::Mul>(out_type);
case binary_operator::DIV: return call<ops::Div>(out_type);
case binary_operator::TRUE_DIV: return call<ops::TrueDiv>(out_type);
case binary_operator::FLOOR_DIV: return call<ops::FloorDiv>(out_type);
case binary_operator::MOD: return call<ops::Mod>(out_type);
case binary_operator::PYMOD: return call<ops::PyMod>(out_type);
case binary_operator::POW: return call<ops::Pow>(out_type);
case binary_operator::BITWISE_AND: return call<ops::BitwiseAnd>(out_type);
case binary_operator::BITWISE_OR: return call<ops::BitwiseOr>(out_type);
case binary_operator::BITWISE_XOR: return call<ops::BitwiseXor>(out_type);
case binary_operator::SHIFT_LEFT: return call<ops::ShiftLeft>(out_type);
case binary_operator::SHIFT_RIGHT: return call<ops::ShiftRight>(out_type);
case binary_operator::SHIFT_RIGHT_UNSIGNED: return call<ops::ShiftRightUnsigned>(out_type);
case binary_operator::LOG_BASE: return call<ops::LogBase>(out_type);
case binary_operator::ATAN2: return call<ops::ATan2>(out_type);
case binary_operator::PMOD: return call<ops::PMod>(out_type);
case binary_operator::NULL_MAX: return call<ops::NullMax>(out_type);
case binary_operator::NULL_MIN: return call<ops::NullMin>(out_type);
/*
case binary_operator::GENERIC_BINARY: // defined in jit only.
*/
Expand All @@ -152,13 +162,13 @@ struct is_supported_operation_functor {
};

template <typename BinaryOperator, typename TypeLhs, typename TypeRhs>
inline constexpr bool bool_op(data_type out)
inline constexpr bool bool_op(data_type out) const
{
return out.id() == type_id::BOOL8 and
is_binary_operation_supported<BinaryOperator>{}.template operator()<TypeLhs, TypeRhs>();
}
template <typename TypeLhs, typename TypeRhs>
inline constexpr bool operator()(data_type out, binary_operator op)
inline constexpr bool operator()(data_type out, binary_operator op) const
{
switch (op) {
// output type should be bool type.
Expand All @@ -175,7 +185,7 @@ struct is_supported_operation_functor {
return bool_op<ops::NullLogicalAnd, TypeLhs, TypeRhs>(out);
case binary_operator::NULL_LOGICAL_OR:
return bool_op<ops::NullLogicalOr, TypeLhs, TypeRhs>(out);
default: return type_dispatcher(out, nested_support_functor<TypeLhs, TypeRhs>{}, op);
default: return nested_support_functor<TypeLhs, TypeRhs>{}(op, out);
}
return false;
}
Expand All @@ -185,7 +195,22 @@ struct is_supported_operation_functor {

std::optional<data_type> get_common_type(data_type out, data_type lhs, data_type rhs)
{
return double_type_dispatcher(lhs, rhs, common_type_functor{}, out);
// Compute the common type of (out, lhs, rhs) if it exists, or the common
// type of (lhs, rhs) if it exists, else return a null optional.
// We can avoid a triple type dispatch by using the definition of
// std::common_type to compute this with double type dispatches.
// Specifically, std::common_type_t<TypeOut, TypeLhs, TypeRhs> is the same as
// std::common_type_t<std::common_type_t<TypeOut, TypeLhs>, TypeRhs>.
auto common_type = double_type_dispatcher(out, lhs, common_type_functor{});
if (common_type.has_value()) {
common_type = double_type_dispatcher(common_type.value(), rhs, common_type_functor{});
}
// If no common type of (out, lhs, rhs) exists, fall back to the common type
// of (lhs, rhs).
if (!common_type.has_value()) {
common_type = double_type_dispatcher(lhs, rhs, common_type_functor{});
}
return common_type;
}

bool is_supported_operation(data_type out, data_type lhs, data_type rhs, binary_operator op)
Expand Down

0 comments on commit 0cc29a0

Please sign in to comment.