Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor binaryop/compiled/util.cpp #10756

Merged
165 changes: 95 additions & 70 deletions cpp/src/binaryop/compiled/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,47 +26,60 @@
namespace cudf::binops::compiled {

namespace {
/**
* @brief Functor that returns optional common type of 2 or 3 given types.
*
*/

struct common_type_functor {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is similar to the old code, except we remove a layer of dispatch (2 instead of 3 types). We make up for this in get_common_type by doing multiple double-dispatches with this functor in a way that is equivalent to the rules of std::common_type_t for multiple input types.

template <typename TypeLhs, typename TypeRhs>
struct nested_common_type_functor {
template <typename TypeOut>
std::optional<data_type> operator()()
{
// If common_type exists
if constexpr (cudf::has_common_type_v<TypeOut, TypeLhs, TypeRhs>) {
using TypeCommon = typename std::common_type<TypeOut, TypeLhs, TypeRhs>::type;
return data_type{type_to_id<TypeCommon>()};
} else if constexpr (cudf::has_common_type_v<TypeLhs, TypeRhs>) {
using TypeCommon = typename std::common_type<TypeLhs, TypeRhs>::type;
// Eg. d=t-t
return data_type{type_to_id<TypeCommon>()};
}

// A compiler bug may cause a compilation error when using empty initializer list to construct
// an std::optional object containing no `data_type` value. Therefore, we should explicitly
// return `std::nullopt` instead.
return std::nullopt;
std::optional<data_type> operator()()
bdice marked this conversation as resolved.
Show resolved Hide resolved
{
if constexpr (cudf::has_common_type_v<TypeLhs, TypeRhs>) {
using TypeCommon = std::common_type_t<TypeLhs, TypeRhs>;
return data_type{type_to_id<TypeCommon>()};
}
};
template <typename TypeLhs, typename TypeRhs>
std::optional<data_type> operator()(data_type out)

// A compiler bug may cause a compilation error when using empty
// initializer list to construct an std::optional object containing no
// `data_type` value. Therefore, we explicitly return `std::nullopt`
// instead.
return std::nullopt;
}
};

struct has_mutable_element_accessor_functor {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We introduce two new functors and methods to dispatch them (has_mutable_element_accessor, is_constructible) that we can use at runtime based on the value of data_type out. The method is_constructible is templated on ReturnType because it is known at compile time based on the invoked return type of the binary operator. However, the output type can be used as runtime information with a single dispatch.

template <typename T>
bool operator()()
bdice marked this conversation as resolved.
Show resolved Hide resolved
{
return type_dispatcher(out, nested_common_type_functor<TypeLhs, TypeRhs>{});
return mutable_column_device_view::has_element_accessor<T>();
}
};

bool has_mutable_element_accessor(data_type t)
{
return type_dispatcher(t, has_mutable_element_accessor_functor{});
}

template <typename InputType>
struct is_constructible_functor {
template <typename TargetType>
bool operator()()
bdice marked this conversation as resolved.
Show resolved Hide resolved
{
return std::is_constructible_v<TargetType, InputType>;
}
};

template <typename InputType>
bool is_constructible(data_type target_type)
{
return type_dispatcher(target_type, is_constructible_functor<InputType>{});
}

/**
* @brief Functor that return true if BinaryOperator supports given input and output types.
*
* @tparam BinaryOperator binary operator functor
*/
template <typename BinaryOperator>
struct is_binary_operation_supported {
// For types where Out type is fixed. (eg. comparison types)
// For types where Out type is fixed. (e.g. comparison types)
template <typename TypeLhs, typename TypeRhs>
inline constexpr bool operator()()
bdice marked this conversation as resolved.
Show resolved Hide resolved
{
Expand All @@ -83,24 +96,22 @@ struct is_binary_operation_supported {
}
}

template <typename TypeOut, typename TypeLhs, typename TypeRhs>
inline constexpr bool operator()()
template <typename TypeLhs, typename TypeRhs>
inline constexpr bool operator()(data_type out_type)
bdice marked this conversation as resolved.
Show resolved Hide resolved
{
if constexpr (column_device_view::has_element_accessor<TypeLhs>() and
column_device_view::has_element_accessor<TypeRhs>() and
(mutable_column_device_view::has_element_accessor<TypeOut>() or
is_fixed_point<TypeOut>())) {
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (std::is_invocable_v<BinaryOperator, common_t, common_t>) {
using ReturnType = std::invoke_result_t<BinaryOperator, common_t, common_t>;
return std::is_constructible_v<TypeOut, ReturnType> or
(is_fixed_point<ReturnType>() and is_fixed_point<TypeOut>());
}
} else {
if constexpr (std::is_invocable_v<BinaryOperator, TypeLhs, TypeRhs>) {
column_device_view::has_element_accessor<TypeRhs>()) {
if (has_mutable_element_accessor(out_type) or is_fixed_point(out_type)) {
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (std::is_invocable_v<BinaryOperator, common_t, common_t>) {
using ReturnType = std::invoke_result_t<BinaryOperator, common_t, common_t>;
return is_constructible<ReturnType>(out_type) or
(is_fixed_point<ReturnType>() and is_fixed_point(out_type));
}
} else if constexpr (std::is_invocable_v<BinaryOperator, TypeLhs, TypeRhs>) {
using ReturnType = std::invoke_result_t<BinaryOperator, TypeLhs, TypeRhs>;
return std::is_constructible_v<TypeOut, ReturnType>;
return is_constructible<ReturnType>(out_type);
}
}
}
Expand All @@ -111,37 +122,36 @@ struct is_binary_operation_supported {
struct is_supported_operation_functor {
template <typename TypeLhs, typename TypeRhs>
struct nested_support_functor {
template <typename BinaryOperator, typename TypeOut>
inline constexpr bool call()
template <typename BinaryOperator>
inline constexpr bool call(data_type out_type)
bdice marked this conversation as resolved.
Show resolved Hide resolved
{
return is_binary_operation_supported<BinaryOperator>{}
.template operator()<TypeOut, TypeLhs, TypeRhs>();
return is_binary_operation_supported<BinaryOperator>{}.template operator()<TypeLhs, TypeRhs>(
out_type);
}
template <typename TypeOut>
inline constexpr bool operator()(binary_operator op)
inline constexpr bool operator()(binary_operator op, data_type out_type)
bdice marked this conversation as resolved.
Show resolved Hide resolved
{
switch (op) {
// clang-format off
case binary_operator::ADD: return call<ops::Add, TypeOut>();
case binary_operator::SUB: return call<ops::Sub, TypeOut>();
case binary_operator::MUL: return call<ops::Mul, TypeOut>();
case binary_operator::DIV: return call<ops::Div, TypeOut>();
case binary_operator::TRUE_DIV: return call<ops::TrueDiv, TypeOut>();
case binary_operator::FLOOR_DIV: return call<ops::FloorDiv, TypeOut>();
case binary_operator::MOD: return call<ops::Mod, TypeOut>();
case binary_operator::PYMOD: return call<ops::PyMod, TypeOut>();
case binary_operator::POW: return call<ops::Pow, TypeOut>();
case binary_operator::BITWISE_AND: return call<ops::BitwiseAnd, TypeOut>();
case binary_operator::BITWISE_OR: return call<ops::BitwiseOr, TypeOut>();
case binary_operator::BITWISE_XOR: return call<ops::BitwiseXor, TypeOut>();
case binary_operator::SHIFT_LEFT: return call<ops::ShiftLeft, TypeOut>();
case binary_operator::SHIFT_RIGHT: return call<ops::ShiftRight, TypeOut>();
case binary_operator::SHIFT_RIGHT_UNSIGNED: return call<ops::ShiftRightUnsigned, TypeOut>();
case binary_operator::LOG_BASE: return call<ops::LogBase, TypeOut>();
case binary_operator::ATAN2: return call<ops::ATan2, TypeOut>();
case binary_operator::PMOD: return call<ops::PMod, TypeOut>();
case binary_operator::NULL_MAX: return call<ops::NullMax, TypeOut>();
case binary_operator::NULL_MIN: return call<ops::NullMin, TypeOut>();
case binary_operator::ADD: return call<ops::Add>(out_type);
case binary_operator::SUB: return call<ops::Sub>(out_type);
case binary_operator::MUL: return call<ops::Mul>(out_type);
case binary_operator::DIV: return call<ops::Div>(out_type);
case binary_operator::TRUE_DIV: return call<ops::TrueDiv>(out_type);
case binary_operator::FLOOR_DIV: return call<ops::FloorDiv>(out_type);
case binary_operator::MOD: return call<ops::Mod>(out_type);
case binary_operator::PYMOD: return call<ops::PyMod>(out_type);
case binary_operator::POW: return call<ops::Pow>(out_type);
case binary_operator::BITWISE_AND: return call<ops::BitwiseAnd>(out_type);
case binary_operator::BITWISE_OR: return call<ops::BitwiseOr>(out_type);
case binary_operator::BITWISE_XOR: return call<ops::BitwiseXor>(out_type);
case binary_operator::SHIFT_LEFT: return call<ops::ShiftLeft>(out_type);
case binary_operator::SHIFT_RIGHT: return call<ops::ShiftRight>(out_type);
case binary_operator::SHIFT_RIGHT_UNSIGNED: return call<ops::ShiftRightUnsigned>(out_type);
case binary_operator::LOG_BASE: return call<ops::LogBase>(out_type);
case binary_operator::ATAN2: return call<ops::ATan2>(out_type);
case binary_operator::PMOD: return call<ops::PMod>(out_type);
case binary_operator::NULL_MAX: return call<ops::NullMax>(out_type);
case binary_operator::NULL_MIN: return call<ops::NullMin>(out_type);
/*
case binary_operator::GENERIC_BINARY: // defined in jit only.
*/
Expand Down Expand Up @@ -175,7 +185,7 @@ struct is_supported_operation_functor {
return bool_op<ops::NullLogicalAnd, TypeLhs, TypeRhs>(out);
case binary_operator::NULL_LOGICAL_OR:
return bool_op<ops::NullLogicalOr, TypeLhs, TypeRhs>(out);
default: return type_dispatcher(out, nested_support_functor<TypeLhs, TypeRhs>{}, op);
default: return nested_support_functor<TypeLhs, TypeRhs>{}(op, out);
}
return false;
}
Expand All @@ -185,7 +195,22 @@ struct is_supported_operation_functor {

std::optional<data_type> get_common_type(data_type out, data_type lhs, data_type rhs)
{
return double_type_dispatcher(lhs, rhs, common_type_functor{}, out);
// Compute the common type of (out, lhs, rhs) if it exists, or the common
// type of (lhs, rhs) if it exists, else return a null optional.
// We can avoid a triple type dispatch by using the definition of
// std::common_type to compute this with double type dispatches.
// Specifically, std::common_type_t<TypeOut, TypeLhs, TypeRhs> is the same as
// std::common_type_t<std::common_type_t<TypeOut, TypeLhs>, TypeRhs>.
auto common_type = double_type_dispatcher(out, lhs, common_type_functor{});
if (common_type.has_value()) {
common_type = double_type_dispatcher(common_type.value(), rhs, common_type_functor{});
}
// If no common type of (out, lhs, rhs) exists, fall back to the common type
// of (lhs, rhs).
if (!common_type.has_value()) {
common_type = double_type_dispatcher(lhs, rhs, common_type_functor{});
}
return common_type;
}

bool is_supported_operation(data_type out, data_type lhs, data_type rhs, binary_operator op)
Expand Down