diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 42a434ba53d..202d10c6929 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -221,6 +221,7 @@ add_library( src/binaryop/compiled/TrueDiv.cu src/binaryop/compiled/binary_ops.cu src/binaryop/compiled/util.cpp + src/binaryop/compiled/util_old.cpp src/labeling/label_bins.cu src/bitmask/null_mask.cu src/bitmask/is_element_valid.cpp diff --git a/cpp/src/binaryop/compiled/binary_ops.hpp b/cpp/src/binaryop/compiled/binary_ops.hpp index d1a40e15326..c0106d2c3c3 100644 --- a/cpp/src/binaryop/compiled/binary_ops.hpp +++ b/cpp/src/binaryop/compiled/binary_ops.hpp @@ -161,6 +161,9 @@ void binary_operation(mutable_column_view& out, * @return common type among @p out, @p lhs, @p rhs. */ std::optional get_common_type(data_type out, data_type lhs, data_type rhs); + +std::optional get_common_type_old(data_type out, data_type lhs, data_type rhs); + /** * @brief Check if input binary operation is supported for the given input and output types. * @@ -172,6 +175,8 @@ std::optional get_common_type(data_type out, data_type lhs, data_type */ bool is_supported_operation(data_type out, data_type lhs, data_type rhs, binary_operator op); +bool is_supported_operation_old(data_type out, data_type lhs, data_type rhs, binary_operator op); + // Defined in individual .cu files. /** * @brief Deploys single type or double type dispatcher that runs binary operation on each element diff --git a/cpp/src/binaryop/compiled/util_old.cpp b/cpp/src/binaryop/compiled/util_old.cpp new file mode 100644 index 00000000000..8106240fcc2 --- /dev/null +++ b/cpp/src/binaryop/compiled/util_old.cpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "operation.cuh" + +#include +#include +#include +#include + +#include + +namespace cudf::binops::compiled { + +namespace { +/** + * @brief Functor that returns optional common type of 2 or 3 given types. + * + */ +struct common_type_functor { + template + struct nested_common_type_functor { + template + std::optional operator()() + { + // If common_type exists + if constexpr (cudf::has_common_type_v) { + using TypeCommon = typename std::common_type::type; + return data_type{type_to_id()}; + } else if constexpr (cudf::has_common_type_v) { + using TypeCommon = typename std::common_type::type; + // Eg. d=t-t + return data_type{type_to_id()}; + } + + // A compiler bug may cause a compilation error when using empty initializer list to construct + // an std::optional object containing no `data_type` value. Therefore, we should explicitly + // return `std::nullopt` instead. + return std::nullopt; + } + }; + template + std::optional operator()(data_type out) + { + return type_dispatcher(out, nested_common_type_functor{}); + } +}; + +/** + * @brief Functor that return true if BinaryOperator supports given input and output types. + * + * @tparam BinaryOperator binary operator functor + */ +template +struct is_binary_operation_supported { + // For types where Out type is fixed. (eg. comparison types) + template + inline constexpr bool operator()() + { + if constexpr (column_device_view::has_element_accessor() and + column_device_view::has_element_accessor()) { + if constexpr (has_common_type_v) { + using common_t = std::common_type_t; + return std::is_invocable_v; + } else { + return std::is_invocable_v; + } + } else { + return false; + } + } + + template + inline constexpr bool operator()() + { + if constexpr (column_device_view::has_element_accessor() and + column_device_view::has_element_accessor() and + (mutable_column_device_view::has_element_accessor() or + is_fixed_point())) { + if constexpr (has_common_type_v) { + using common_t = std::common_type_t; + if constexpr (std::is_invocable_v) { + using ReturnType = std::invoke_result_t; + return std::is_constructible_v or + (is_fixed_point() and is_fixed_point()); + } + } else { + if constexpr (std::is_invocable_v) { + using ReturnType = std::invoke_result_t; + return std::is_constructible_v; + } + } + } + return false; + } +}; + +struct is_supported_operation_functor { + template + struct nested_support_functor { + template + inline constexpr bool call() + { + return is_binary_operation_supported{} + .template operator()(); + } + template + inline constexpr bool operator()(binary_operator op) + { + switch (op) { + // clang-format off + case binary_operator::ADD: return call(); + case binary_operator::SUB: return call(); + case binary_operator::MUL: return call(); + case binary_operator::DIV: return call(); + case binary_operator::TRUE_DIV: return call(); + case binary_operator::FLOOR_DIV: return call(); + case binary_operator::MOD: return call(); + case binary_operator::PYMOD: return call(); + case binary_operator::POW: return call(); + case binary_operator::BITWISE_AND: return call(); + case binary_operator::BITWISE_OR: return call(); + case binary_operator::BITWISE_XOR: return call(); + case binary_operator::SHIFT_LEFT: return call(); + case binary_operator::SHIFT_RIGHT: return call(); + case binary_operator::SHIFT_RIGHT_UNSIGNED: return call(); + case binary_operator::LOG_BASE: return call(); + case binary_operator::ATAN2: return call(); + case binary_operator::PMOD: return call(); + case binary_operator::NULL_MAX: return call(); + case binary_operator::NULL_MIN: return call(); + /* + case binary_operator::GENERIC_BINARY: // defined in jit only. + */ + default: return false; + // clang-format on + } + } + }; + + template + inline constexpr bool bool_op(data_type out) + { + return out.id() == type_id::BOOL8 and + is_binary_operation_supported{}.template operator()(); + } + template + inline constexpr bool operator()(data_type out, binary_operator op) + { + switch (op) { + // output type should be bool type. + case binary_operator::LOGICAL_AND: return bool_op(out); + case binary_operator::LOGICAL_OR: return bool_op(out); + case binary_operator::EQUAL: return bool_op(out); + case binary_operator::NOT_EQUAL: return bool_op(out); + case binary_operator::LESS: return bool_op(out); + case binary_operator::GREATER: return bool_op(out); + case binary_operator::LESS_EQUAL: return bool_op(out); + case binary_operator::GREATER_EQUAL: return bool_op(out); + case binary_operator::NULL_EQUALS: return bool_op(out); + case binary_operator::NULL_LOGICAL_AND: + return bool_op(out); + case binary_operator::NULL_LOGICAL_OR: + return bool_op(out); + default: return type_dispatcher(out, nested_support_functor{}, op); + } + return false; + } +}; + +} // namespace + +std::optional get_common_type_old(data_type out, data_type lhs, data_type rhs) +{ + return double_type_dispatcher(lhs, rhs, common_type_functor{}, out); +} + +bool is_supported_operation_old(data_type out, data_type lhs, data_type rhs, binary_operator op) +{ + return double_type_dispatcher(lhs, rhs, is_supported_operation_functor{}, out, op); +} +} // namespace cudf::binops::compiled diff --git a/cpp/tests/binaryop/binop-verify-input-test.cpp b/cpp/tests/binaryop/binop-verify-input-test.cpp index 167fbc22bde..ac0ed3e081c 100644 --- a/cpp/tests/binaryop/binop-verify-input-test.cpp +++ b/cpp/tests/binaryop/binop-verify-input-test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Copyright 2018-2019 BlazingDB, Inc. * Copyright 2018 Christian Noboa Mardini @@ -19,8 +19,12 @@ #include +#include + #include +#include + namespace cudf { namespace test { namespace binop { @@ -54,6 +58,158 @@ TEST_F(BinopVerifyInputTest, Vector_Vector_ErrorSecondOperandVectorZeroSize) cudf::logic_error); } +struct BinopTypeTest : public BinaryOperationTest { +}; +TEST_F(BinopTypeTest, GetCommonTypeTest) +{ + auto const type_ids = std::vector{ + // cudf::type_id::EMPTY, ///< Always null with no underlying data + cudf::type_id::INT8, ///< 1 byte signed integer + cudf::type_id::INT16, ///< 2 byte signed integer + cudf::type_id::INT32, ///< 4 byte signed integer + cudf::type_id::INT64, ///< 8 byte signed integer + cudf::type_id::UINT8, ///< 1 byte unsigned integer + cudf::type_id::UINT16, ///< 2 byte unsigned integer + cudf::type_id::UINT32, ///< 4 byte unsigned integer + cudf::type_id::UINT64, ///< 8 byte unsigned integer + cudf::type_id::FLOAT32, ///< 4 byte floating point + cudf::type_id::FLOAT64, ///< 8 byte floating point + cudf::type_id::BOOL8, ///< Boolean using one byte per value, 0 == false, else true + cudf::type_id::TIMESTAMP_DAYS, ///< point in time in days since Unix Epoch in int32 + cudf::type_id::TIMESTAMP_SECONDS, ///< point in time in seconds since Unix Epoch in int64 + cudf::type_id::TIMESTAMP_MILLISECONDS, ///< point in time in milliseconds since Unix Epoch in + ///< int64 + cudf::type_id::TIMESTAMP_MICROSECONDS, ///< point in time in microseconds since Unix Epoch in + ///< int64 + cudf::type_id::TIMESTAMP_NANOSECONDS, ///< point in time in nanoseconds since Unix Epoch in + ///< int64 + cudf::type_id::DURATION_DAYS, ///< time interval of days in int32 + cudf::type_id::DURATION_SECONDS, ///< time interval of seconds in int64 + cudf::type_id::DURATION_MILLISECONDS, ///< time interval of milliseconds in int64 + cudf::type_id::DURATION_MICROSECONDS, ///< time interval of microseconds in int64 + cudf::type_id::DURATION_NANOSECONDS, ///< time interval of nanoseconds in int64 + cudf::type_id::DICTIONARY32, ///< Dictionary type using int32 indices + cudf::type_id::STRING, ///< String elements + cudf::type_id::LIST, ///< List elements + cudf::type_id::DECIMAL32, ///< Fixed-point type with int32_t + cudf::type_id::DECIMAL64, ///< Fixed-point type with int64_t + cudf::type_id::DECIMAL128, ///< Fixed-point type with __int128_t + cudf::type_id::STRUCT ///< Struct elements + }; + for (auto const& t1 : type_ids) { + for (auto const& t2 : type_ids) { + for (auto const& t3 : type_ids) { + auto const d1 = cudf::data_type(t1); + auto const d2 = cudf::data_type(t2); + auto const d3 = cudf::data_type(t3); + auto const old = cudf::binops::compiled::get_common_type_old(d1, d2, d3); + (void)old; + auto const new_ = cudf::binops::compiled::get_common_type(d1, d2, d3); + (void)new_; + std::cerr << static_cast(t1) << ", " << static_cast(t2) << ", " + << static_cast(t3) << std::endl; + EXPECT_EQ(old, new_); + } + } + } +} + +TEST_F(BinopTypeTest, IsSupportedOperationTest) +{ + auto const binops = std::vector{ + cudf::binary_operator::ADD, ///< operator + + cudf::binary_operator::SUB, ///< operator - + cudf::binary_operator::MUL, ///< operator * + cudf::binary_operator::DIV, ///< operator / using common type of lhs and rhs + cudf::binary_operator::TRUE_DIV, ///< operator / after promoting type to floating point + cudf::binary_operator::FLOOR_DIV, ///< operator / after promoting to 64 bit floating point and + ///< then + cudf::binary_operator::MOD, ///< operator % + cudf::binary_operator::PMOD, ///< positive modulo operator + cudf::binary_operator::PYMOD, ///< operator % but following Python's sign rules for negatives + cudf::binary_operator::POW, ///< lhs ^ rhs + cudf::binary_operator::LOG_BASE, ///< logarithm to the base + cudf::binary_operator::ATAN2, ///< 2-argument arctangent + cudf::binary_operator::SHIFT_LEFT, ///< operator << + cudf::binary_operator::SHIFT_RIGHT, ///< operator >> + cudf::binary_operator::SHIFT_RIGHT_UNSIGNED, ///< operator >>> (from Java) + cudf::binary_operator::BITWISE_AND, ///< operator & + cudf::binary_operator::BITWISE_OR, ///< operator | + cudf::binary_operator::BITWISE_XOR, ///< operator ^ + cudf::binary_operator::LOGICAL_AND, ///< operator && + cudf::binary_operator::LOGICAL_OR, ///< operator || + cudf::binary_operator::EQUAL, ///< operator == + cudf::binary_operator::NOT_EQUAL, ///< operator != + cudf::binary_operator::LESS, ///< operator < + cudf::binary_operator::GREATER, ///< operator > + cudf::binary_operator::LESS_EQUAL, ///< operator <= + cudf::binary_operator::GREATER_EQUAL, ///< operator >= + cudf::binary_operator::NULL_EQUALS, ///< Returns true when both operands are null; false when + ///< one is null; the + cudf::binary_operator::NULL_MAX, ///< Returns max of operands when both are non-null; returns + ///< the non-null + cudf::binary_operator::NULL_MIN, ///< Returns min of operands when both are non-null; returns + ///< the non-null + cudf::binary_operator::GENERIC_BINARY, ///< generic binary operator to be generated with input + cudf::binary_operator::NULL_LOGICAL_AND, ///< operator && with Spark rules: (null, null) is + ///< null, (null, true) is + cudf::binary_operator::NULL_LOGICAL_OR, ///< operator || with Spark rules: (null, null) is + ///< null, (null, true) is true, + cudf::binary_operator::INVALID_BINARY ///< invalid operation + }; + + auto const type_ids = std::vector{ + // cudf::type_id::EMPTY, ///< Always null with no underlying data + cudf::type_id::INT8, ///< 1 byte signed integer + cudf::type_id::INT16, ///< 2 byte signed integer + cudf::type_id::INT32, ///< 4 byte signed integer + cudf::type_id::INT64, ///< 8 byte signed integer + cudf::type_id::UINT8, ///< 1 byte unsigned integer + cudf::type_id::UINT16, ///< 2 byte unsigned integer + cudf::type_id::UINT32, ///< 4 byte unsigned integer + cudf::type_id::UINT64, ///< 8 byte unsigned integer + cudf::type_id::FLOAT32, ///< 4 byte floating point + cudf::type_id::FLOAT64, ///< 8 byte floating point + cudf::type_id::BOOL8, ///< Boolean using one byte per value, 0 == false, else true + cudf::type_id::TIMESTAMP_DAYS, ///< point in time in days since Unix Epoch in int32 + cudf::type_id::TIMESTAMP_SECONDS, ///< point in time in seconds since Unix Epoch in int64 + cudf::type_id::TIMESTAMP_MILLISECONDS, ///< point in time in milliseconds since Unix Epoch in + ///< int64 + cudf::type_id::TIMESTAMP_MICROSECONDS, ///< point in time in microseconds since Unix Epoch in + ///< int64 + cudf::type_id::TIMESTAMP_NANOSECONDS, ///< point in time in nanoseconds since Unix Epoch in + ///< int64 + cudf::type_id::DURATION_DAYS, ///< time interval of days in int32 + cudf::type_id::DURATION_SECONDS, ///< time interval of seconds in int64 + cudf::type_id::DURATION_MILLISECONDS, ///< time interval of milliseconds in int64 + cudf::type_id::DURATION_MICROSECONDS, ///< time interval of microseconds in int64 + cudf::type_id::DURATION_NANOSECONDS, ///< time interval of nanoseconds in int64 + cudf::type_id::DICTIONARY32, ///< Dictionary type using int32 indices + cudf::type_id::STRING, ///< String elements + cudf::type_id::LIST, ///< List elements + cudf::type_id::DECIMAL32, ///< Fixed-point type with int32_t + cudf::type_id::DECIMAL64, ///< Fixed-point type with int64_t + cudf::type_id::DECIMAL128, ///< Fixed-point type with __int128_t + cudf::type_id::STRUCT ///< Struct elements + }; + for (auto const& op : binops) { + for (auto const& t1 : type_ids) { + for (auto const& t2 : type_ids) { + for (auto const& t3 : type_ids) { + auto const d1 = cudf::data_type(t1); + auto const d2 = cudf::data_type(t2); + auto const d3 = cudf::data_type(t3); + auto const old = cudf::binops::compiled::is_supported_operation_old(d1, d2, d3, op); + auto const new_ = cudf::binops::compiled::is_supported_operation(d1, d2, d3, op); + std::cerr << static_cast(t1) << ", " << static_cast(t2) << ", " + << static_cast(t3) << ", " << static_cast(op) << std::endl; + EXPECT_EQ(old, new_); + } + } + } + } +} + } // namespace binop } // namespace test } // namespace cudf