-
Notifications
You must be signed in to change notification settings - Fork 915
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor binaryop/compiled/util.cpp #10756
Changes from 6 commits
bd9b0a6
0c10a37
6f068da
ee2c26a
1c256fb
e2e2032
be5537e
1b9656c
ecfadd1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,47 +26,60 @@ | |
namespace cudf::binops::compiled { | ||
|
||
namespace { | ||
/** | ||
* @brief Functor that returns optional common type of 2 or 3 given types. | ||
* | ||
*/ | ||
|
||
struct common_type_functor { | ||
template <typename TypeLhs, typename TypeRhs> | ||
struct nested_common_type_functor { | ||
template <typename TypeOut> | ||
std::optional<data_type> operator()() | ||
{ | ||
// If common_type exists | ||
if constexpr (cudf::has_common_type_v<TypeOut, TypeLhs, TypeRhs>) { | ||
using TypeCommon = typename std::common_type<TypeOut, TypeLhs, TypeRhs>::type; | ||
return data_type{type_to_id<TypeCommon>()}; | ||
} else if constexpr (cudf::has_common_type_v<TypeLhs, TypeRhs>) { | ||
using TypeCommon = typename std::common_type<TypeLhs, TypeRhs>::type; | ||
// Eg. d=t-t | ||
return data_type{type_to_id<TypeCommon>()}; | ||
} | ||
|
||
// A compiler bug may cause a compilation error when using empty initializer list to construct | ||
// an std::optional object containing no `data_type` value. Therefore, we should explicitly | ||
// return `std::nullopt` instead. | ||
return std::nullopt; | ||
std::optional<data_type> operator()() | ||
bdice marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
if constexpr (cudf::has_common_type_v<TypeLhs, TypeRhs>) { | ||
using TypeCommon = std::common_type_t<TypeLhs, TypeRhs>; | ||
return data_type{type_to_id<TypeCommon>()}; | ||
} | ||
}; | ||
template <typename TypeLhs, typename TypeRhs> | ||
std::optional<data_type> operator()(data_type out) | ||
|
||
// A compiler bug may cause a compilation error when using empty | ||
// initializer list to construct an std::optional object containing no | ||
// `data_type` value. Therefore, we explicitly return `std::nullopt` | ||
// instead. | ||
return std::nullopt; | ||
} | ||
}; | ||
|
||
struct has_mutable_element_accessor_functor { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We introduce two new functors and methods to dispatch them ( |
||
template <typename T> | ||
bool operator()() | ||
bdice marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
return type_dispatcher(out, nested_common_type_functor<TypeLhs, TypeRhs>{}); | ||
return mutable_column_device_view::has_element_accessor<T>(); | ||
} | ||
}; | ||
|
||
bool has_mutable_element_accessor(data_type t) | ||
{ | ||
return type_dispatcher(t, has_mutable_element_accessor_functor{}); | ||
} | ||
|
||
template <typename InputType> | ||
struct is_constructible_functor { | ||
template <typename TargetType> | ||
bool operator()() | ||
bdice marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
return std::is_constructible_v<TargetType, InputType>; | ||
} | ||
}; | ||
|
||
template <typename InputType> | ||
bool is_constructible(data_type target_type) | ||
{ | ||
return type_dispatcher(target_type, is_constructible_functor<InputType>{}); | ||
} | ||
|
||
/** | ||
* @brief Functor that return true if BinaryOperator supports given input and output types. | ||
* | ||
* @tparam BinaryOperator binary operator functor | ||
*/ | ||
template <typename BinaryOperator> | ||
struct is_binary_operation_supported { | ||
// For types where Out type is fixed. (eg. comparison types) | ||
// For types where Out type is fixed. (e.g. comparison types) | ||
template <typename TypeLhs, typename TypeRhs> | ||
inline constexpr bool operator()() | ||
bdice marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
|
@@ -83,24 +96,22 @@ struct is_binary_operation_supported { | |
} | ||
} | ||
|
||
template <typename TypeOut, typename TypeLhs, typename TypeRhs> | ||
inline constexpr bool operator()() | ||
template <typename TypeLhs, typename TypeRhs> | ||
inline constexpr bool operator()(data_type out_type) | ||
bdice marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
if constexpr (column_device_view::has_element_accessor<TypeLhs>() and | ||
column_device_view::has_element_accessor<TypeRhs>() and | ||
(mutable_column_device_view::has_element_accessor<TypeOut>() or | ||
is_fixed_point<TypeOut>())) { | ||
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) { | ||
using common_t = std::common_type_t<TypeLhs, TypeRhs>; | ||
if constexpr (std::is_invocable_v<BinaryOperator, common_t, common_t>) { | ||
using ReturnType = std::invoke_result_t<BinaryOperator, common_t, common_t>; | ||
return std::is_constructible_v<TypeOut, ReturnType> or | ||
(is_fixed_point<ReturnType>() and is_fixed_point<TypeOut>()); | ||
} | ||
} else { | ||
if constexpr (std::is_invocable_v<BinaryOperator, TypeLhs, TypeRhs>) { | ||
column_device_view::has_element_accessor<TypeRhs>()) { | ||
if (has_mutable_element_accessor(out_type) or is_fixed_point(out_type)) { | ||
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) { | ||
using common_t = std::common_type_t<TypeLhs, TypeRhs>; | ||
if constexpr (std::is_invocable_v<BinaryOperator, common_t, common_t>) { | ||
using ReturnType = std::invoke_result_t<BinaryOperator, common_t, common_t>; | ||
return is_constructible<ReturnType>(out_type) or | ||
(is_fixed_point<ReturnType>() and is_fixed_point(out_type)); | ||
} | ||
} else if constexpr (std::is_invocable_v<BinaryOperator, TypeLhs, TypeRhs>) { | ||
using ReturnType = std::invoke_result_t<BinaryOperator, TypeLhs, TypeRhs>; | ||
return std::is_constructible_v<TypeOut, ReturnType>; | ||
return is_constructible<ReturnType>(out_type); | ||
} | ||
} | ||
} | ||
|
@@ -111,37 +122,36 @@ struct is_binary_operation_supported { | |
struct is_supported_operation_functor { | ||
template <typename TypeLhs, typename TypeRhs> | ||
struct nested_support_functor { | ||
template <typename BinaryOperator, typename TypeOut> | ||
inline constexpr bool call() | ||
template <typename BinaryOperator> | ||
inline constexpr bool call(data_type out_type) | ||
bdice marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
return is_binary_operation_supported<BinaryOperator>{} | ||
.template operator()<TypeOut, TypeLhs, TypeRhs>(); | ||
return is_binary_operation_supported<BinaryOperator>{}.template operator()<TypeLhs, TypeRhs>( | ||
out_type); | ||
} | ||
template <typename TypeOut> | ||
inline constexpr bool operator()(binary_operator op) | ||
inline constexpr bool operator()(binary_operator op, data_type out_type) | ||
bdice marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
switch (op) { | ||
// clang-format off | ||
case binary_operator::ADD: return call<ops::Add, TypeOut>(); | ||
case binary_operator::SUB: return call<ops::Sub, TypeOut>(); | ||
case binary_operator::MUL: return call<ops::Mul, TypeOut>(); | ||
case binary_operator::DIV: return call<ops::Div, TypeOut>(); | ||
case binary_operator::TRUE_DIV: return call<ops::TrueDiv, TypeOut>(); | ||
case binary_operator::FLOOR_DIV: return call<ops::FloorDiv, TypeOut>(); | ||
case binary_operator::MOD: return call<ops::Mod, TypeOut>(); | ||
case binary_operator::PYMOD: return call<ops::PyMod, TypeOut>(); | ||
case binary_operator::POW: return call<ops::Pow, TypeOut>(); | ||
case binary_operator::BITWISE_AND: return call<ops::BitwiseAnd, TypeOut>(); | ||
case binary_operator::BITWISE_OR: return call<ops::BitwiseOr, TypeOut>(); | ||
case binary_operator::BITWISE_XOR: return call<ops::BitwiseXor, TypeOut>(); | ||
case binary_operator::SHIFT_LEFT: return call<ops::ShiftLeft, TypeOut>(); | ||
case binary_operator::SHIFT_RIGHT: return call<ops::ShiftRight, TypeOut>(); | ||
case binary_operator::SHIFT_RIGHT_UNSIGNED: return call<ops::ShiftRightUnsigned, TypeOut>(); | ||
case binary_operator::LOG_BASE: return call<ops::LogBase, TypeOut>(); | ||
case binary_operator::ATAN2: return call<ops::ATan2, TypeOut>(); | ||
case binary_operator::PMOD: return call<ops::PMod, TypeOut>(); | ||
case binary_operator::NULL_MAX: return call<ops::NullMax, TypeOut>(); | ||
case binary_operator::NULL_MIN: return call<ops::NullMin, TypeOut>(); | ||
case binary_operator::ADD: return call<ops::Add>(out_type); | ||
case binary_operator::SUB: return call<ops::Sub>(out_type); | ||
case binary_operator::MUL: return call<ops::Mul>(out_type); | ||
case binary_operator::DIV: return call<ops::Div>(out_type); | ||
case binary_operator::TRUE_DIV: return call<ops::TrueDiv>(out_type); | ||
case binary_operator::FLOOR_DIV: return call<ops::FloorDiv>(out_type); | ||
case binary_operator::MOD: return call<ops::Mod>(out_type); | ||
case binary_operator::PYMOD: return call<ops::PyMod>(out_type); | ||
case binary_operator::POW: return call<ops::Pow>(out_type); | ||
case binary_operator::BITWISE_AND: return call<ops::BitwiseAnd>(out_type); | ||
case binary_operator::BITWISE_OR: return call<ops::BitwiseOr>(out_type); | ||
case binary_operator::BITWISE_XOR: return call<ops::BitwiseXor>(out_type); | ||
case binary_operator::SHIFT_LEFT: return call<ops::ShiftLeft>(out_type); | ||
case binary_operator::SHIFT_RIGHT: return call<ops::ShiftRight>(out_type); | ||
case binary_operator::SHIFT_RIGHT_UNSIGNED: return call<ops::ShiftRightUnsigned>(out_type); | ||
case binary_operator::LOG_BASE: return call<ops::LogBase>(out_type); | ||
case binary_operator::ATAN2: return call<ops::ATan2>(out_type); | ||
case binary_operator::PMOD: return call<ops::PMod>(out_type); | ||
case binary_operator::NULL_MAX: return call<ops::NullMax>(out_type); | ||
case binary_operator::NULL_MIN: return call<ops::NullMin>(out_type); | ||
/* | ||
case binary_operator::GENERIC_BINARY: // defined in jit only. | ||
*/ | ||
|
@@ -175,7 +185,7 @@ struct is_supported_operation_functor { | |
return bool_op<ops::NullLogicalAnd, TypeLhs, TypeRhs>(out); | ||
case binary_operator::NULL_LOGICAL_OR: | ||
return bool_op<ops::NullLogicalOr, TypeLhs, TypeRhs>(out); | ||
default: return type_dispatcher(out, nested_support_functor<TypeLhs, TypeRhs>{}, op); | ||
default: return nested_support_functor<TypeLhs, TypeRhs>{}(op, out); | ||
} | ||
return false; | ||
} | ||
|
@@ -185,7 +195,22 @@ struct is_supported_operation_functor { | |
|
||
std::optional<data_type> get_common_type(data_type out, data_type lhs, data_type rhs) | ||
{ | ||
return double_type_dispatcher(lhs, rhs, common_type_functor{}, out); | ||
// Compute the common type of (out, lhs, rhs) if it exists, or the common | ||
// type of (lhs, rhs) if it exists, else return a null optional. | ||
// We can avoid a triple type dispatch by using the definition of | ||
// std::common_type to compute this with double type dispatches. | ||
// Specifically, std::common_type_t<TypeOut, TypeLhs, TypeRhs> is the same as | ||
// std::common_type_t<std::common_type_t<TypeOut, TypeLhs>, TypeRhs>. | ||
auto common_type = double_type_dispatcher(out, lhs, common_type_functor{}); | ||
if (common_type.has_value()) { | ||
common_type = double_type_dispatcher(common_type.value(), rhs, common_type_functor{}); | ||
} | ||
// If no common type of (out, lhs, rhs) exists, fall back to the common type | ||
// of (lhs, rhs). | ||
if (!common_type.has_value()) { | ||
common_type = double_type_dispatcher(lhs, rhs, common_type_functor{}); | ||
} | ||
return common_type; | ||
} | ||
|
||
bool is_supported_operation(data_type out, data_type lhs, data_type rhs, binary_operator op) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is similar to the old code, except we remove a layer of dispatch (2 instead of 3 types). We make up for this in
get_common_type
by doing multiple double-dispatches with this functor in a way that is equivalent to the rules ofstd::common_type_t
for multiple input types.