diff --git a/cpp/src/binaryop/compiled/util.cpp b/cpp/src/binaryop/compiled/util.cpp index d8f1eb03a16..91fa04be6e2 100644 --- a/cpp/src/binaryop/compiled/util.cpp +++ b/cpp/src/binaryop/compiled/util.cpp @@ -26,39 +26,52 @@ namespace cudf::binops::compiled { namespace { -/** - * @brief Functor that returns optional common type of 2 or 3 given types. - * - */ + struct common_type_functor { template - struct nested_common_type_functor { - template - std::optional operator()() - { - // If common_type exists - if constexpr (cudf::has_common_type_v) { - using TypeCommon = typename std::common_type::type; - return data_type{type_to_id()}; - } else if constexpr (cudf::has_common_type_v) { - using TypeCommon = typename std::common_type::type; - // Eg. d=t-t - return data_type{type_to_id()}; - } - - // A compiler bug may cause a compilation error when using empty initializer list to construct - // an std::optional object containing no `data_type` value. Therefore, we should explicitly - // return `std::nullopt` instead. - return std::nullopt; + std::optional operator()() const + { + if constexpr (cudf::has_common_type_v) { + using TypeCommon = std::common_type_t; + return data_type{type_to_id()}; } - }; - template - std::optional operator()(data_type out) + + // A compiler bug may cause a compilation error when using empty + // initializer list to construct an std::optional object containing no + // `data_type` value. Therefore, we explicitly return `std::nullopt` + // instead. + return std::nullopt; + } +}; + +struct has_mutable_element_accessor_functor { + template + bool operator()() const + { + return mutable_column_device_view::has_element_accessor(); + } +}; + +bool has_mutable_element_accessor(data_type t) +{ + return type_dispatcher(t, has_mutable_element_accessor_functor{}); +} + +template +struct is_constructible_functor { + template + bool operator()() const { - return type_dispatcher(out, nested_common_type_functor{}); + return std::is_constructible_v; } }; +template +bool is_constructible(data_type target_type) +{ + return type_dispatcher(target_type, is_constructible_functor{}); +} + /** * @brief Functor that return true if BinaryOperator supports given input and output types. * @@ -66,9 +79,9 @@ struct common_type_functor { */ template struct is_binary_operation_supported { - // For types where Out type is fixed. (eg. comparison types) + // For types where Out type is fixed. (e.g. comparison types) template - inline constexpr bool operator()() + inline constexpr bool operator()() const { if constexpr (column_device_view::has_element_accessor() and column_device_view::has_element_accessor()) { @@ -83,24 +96,22 @@ struct is_binary_operation_supported { } } - template - inline constexpr bool operator()() + template + inline constexpr bool operator()(data_type out_type) const { if constexpr (column_device_view::has_element_accessor() and - column_device_view::has_element_accessor() and - (mutable_column_device_view::has_element_accessor() or - is_fixed_point())) { - if constexpr (has_common_type_v) { - using common_t = std::common_type_t; - if constexpr (std::is_invocable_v) { - using ReturnType = std::invoke_result_t; - return std::is_constructible_v or - (is_fixed_point() and is_fixed_point()); - } - } else { - if constexpr (std::is_invocable_v) { + column_device_view::has_element_accessor()) { + if (has_mutable_element_accessor(out_type) or is_fixed_point(out_type)) { + if constexpr (has_common_type_v) { + using common_t = std::common_type_t; + if constexpr (std::is_invocable_v) { + using ReturnType = std::invoke_result_t; + return is_constructible(out_type) or + (is_fixed_point() and is_fixed_point(out_type)); + } + } else if constexpr (std::is_invocable_v) { using ReturnType = std::invoke_result_t; - return std::is_constructible_v; + return is_constructible(out_type); } } } @@ -111,37 +122,36 @@ struct is_binary_operation_supported { struct is_supported_operation_functor { template struct nested_support_functor { - template - inline constexpr bool call() + template + inline constexpr bool call(data_type out_type) const { - return is_binary_operation_supported{} - .template operator()(); + return is_binary_operation_supported{}.template operator()( + out_type); } - template - inline constexpr bool operator()(binary_operator op) + inline constexpr bool operator()(binary_operator op, data_type out_type) const { switch (op) { // clang-format off - case binary_operator::ADD: return call(); - case binary_operator::SUB: return call(); - case binary_operator::MUL: return call(); - case binary_operator::DIV: return call(); - case binary_operator::TRUE_DIV: return call(); - case binary_operator::FLOOR_DIV: return call(); - case binary_operator::MOD: return call(); - case binary_operator::PYMOD: return call(); - case binary_operator::POW: return call(); - case binary_operator::BITWISE_AND: return call(); - case binary_operator::BITWISE_OR: return call(); - case binary_operator::BITWISE_XOR: return call(); - case binary_operator::SHIFT_LEFT: return call(); - case binary_operator::SHIFT_RIGHT: return call(); - case binary_operator::SHIFT_RIGHT_UNSIGNED: return call(); - case binary_operator::LOG_BASE: return call(); - case binary_operator::ATAN2: return call(); - case binary_operator::PMOD: return call(); - case binary_operator::NULL_MAX: return call(); - case binary_operator::NULL_MIN: return call(); + case binary_operator::ADD: return call(out_type); + case binary_operator::SUB: return call(out_type); + case binary_operator::MUL: return call(out_type); + case binary_operator::DIV: return call(out_type); + case binary_operator::TRUE_DIV: return call(out_type); + case binary_operator::FLOOR_DIV: return call(out_type); + case binary_operator::MOD: return call(out_type); + case binary_operator::PYMOD: return call(out_type); + case binary_operator::POW: return call(out_type); + case binary_operator::BITWISE_AND: return call(out_type); + case binary_operator::BITWISE_OR: return call(out_type); + case binary_operator::BITWISE_XOR: return call(out_type); + case binary_operator::SHIFT_LEFT: return call(out_type); + case binary_operator::SHIFT_RIGHT: return call(out_type); + case binary_operator::SHIFT_RIGHT_UNSIGNED: return call(out_type); + case binary_operator::LOG_BASE: return call(out_type); + case binary_operator::ATAN2: return call(out_type); + case binary_operator::PMOD: return call(out_type); + case binary_operator::NULL_MAX: return call(out_type); + case binary_operator::NULL_MIN: return call(out_type); /* case binary_operator::GENERIC_BINARY: // defined in jit only. */ @@ -152,13 +162,13 @@ struct is_supported_operation_functor { }; template - inline constexpr bool bool_op(data_type out) + inline constexpr bool bool_op(data_type out) const { return out.id() == type_id::BOOL8 and is_binary_operation_supported{}.template operator()(); } template - inline constexpr bool operator()(data_type out, binary_operator op) + inline constexpr bool operator()(data_type out, binary_operator op) const { switch (op) { // output type should be bool type. @@ -175,7 +185,7 @@ struct is_supported_operation_functor { return bool_op(out); case binary_operator::NULL_LOGICAL_OR: return bool_op(out); - default: return type_dispatcher(out, nested_support_functor{}, op); + default: return nested_support_functor{}(op, out); } return false; } @@ -185,7 +195,22 @@ struct is_supported_operation_functor { std::optional get_common_type(data_type out, data_type lhs, data_type rhs) { - return double_type_dispatcher(lhs, rhs, common_type_functor{}, out); + // Compute the common type of (out, lhs, rhs) if it exists, or the common + // type of (lhs, rhs) if it exists, else return a null optional. + // We can avoid a triple type dispatch by using the definition of + // std::common_type to compute this with double type dispatches. + // Specifically, std::common_type_t is the same as + // std::common_type_t, TypeRhs>. + auto common_type = double_type_dispatcher(out, lhs, common_type_functor{}); + if (common_type.has_value()) { + common_type = double_type_dispatcher(common_type.value(), rhs, common_type_functor{}); + } + // If no common type of (out, lhs, rhs) exists, fall back to the common type + // of (lhs, rhs). + if (!common_type.has_value()) { + common_type = double_type_dispatcher(lhs, rhs, common_type_functor{}); + } + return common_type; } bool is_supported_operation(data_type out, data_type lhs, data_type rhs, binary_operator op)