Skip to content

Commit

Permalink
Make binary operators work between fixed-point and floating args (rap…
Browse files Browse the repository at this point in the history
…idsai#16116)

Some of the binary operators in cuDF don't work between fixed_point and floating-point numbers after [this earlier PR](rapidsai#15438) removed the ability to construct and implicitly cast fixed_point numbers from floating point numbers. This PR restores that functionality by detecting and performing the necessary explicit casts, and adds tests for the supported operators. 

Note that the `binary_op_has_common_type` code is modeled after `has_common_type` found in traits.hpp. 

This closes [issue 16090](rapidsai#16090)

Authors:
  - Paul Mattione (https://github.com/pmattione-nvidia)

Approvers:
  - Jayjeet Chakraborty (https://github.com/JayjeetAtGithub)
  - Karthikeyan (https://github.com/karthikeyann)

URL: rapidsai#16116
  • Loading branch information
pmattione-nvidia authored Jun 28, 2024
1 parent 224ac5b commit 673d766
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 9 deletions.
50 changes: 50 additions & 0 deletions cpp/include/cudf/binaryop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,56 @@ enum class binary_operator : int32_t {
///< (null, false) is null, and (valid, valid) == LOGICAL_OR(valid, valid)
INVALID_BINARY ///< invalid operation
};

/// Binary operation common type default
template <typename L, typename R, typename = void>
struct binary_op_common_type {};

/// Binary operation common type specialization
template <typename L, typename R>
struct binary_op_common_type<L, R, std::enable_if_t<has_common_type_v<L, R>>> {
/// The common type of the template parameters
using type = std::common_type_t<L, R>;
};

/// Binary operation common type specialization
template <typename L, typename R>
struct binary_op_common_type<
L,
R,
std::enable_if_t<is_fixed_point<L>() && cuda::std::is_floating_point_v<R>>> {
/// The common type of the template parameters
using type = L;
};

/// Binary operation common type specialization
template <typename L, typename R>
struct binary_op_common_type<
L,
R,
std::enable_if_t<is_fixed_point<R>() && cuda::std::is_floating_point_v<L>>> {
/// The common type of the template parameters
using type = R;
};

/// Binary operation common type helper
template <typename L, typename R>
using binary_op_common_type_t = typename binary_op_common_type<L, R>::type;

namespace detail {
template <typename AlwaysVoid, typename L, typename R>
struct binary_op_has_common_type_impl : std::false_type {};

template <typename L, typename R>
struct binary_op_has_common_type_impl<std::void_t<binary_op_common_type_t<L, R>>, L, R>
: std::true_type {};
} // namespace detail

/// Checks if binary operation types have a common type
template <typename L, typename R>
constexpr inline bool binary_op_has_common_type_v =
detail::binary_op_has_common_type_impl<void, L, R>::value;

/**
* @brief Performs a binary operation between a scalar and a column.
*
Expand Down
14 changes: 11 additions & 3 deletions cpp/src/binaryop/compiled/binary_ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,16 @@ struct type_casted_accessor {
column_device_view const& col,
bool is_scalar) const
{
if constexpr (column_device_view::has_element_accessor<Element>() and
std::is_convertible_v<Element, CastType>)
return static_cast<CastType>(col.element<Element>(is_scalar ? 0 : i));
if constexpr (column_device_view::has_element_accessor<Element>()) {
auto const element = col.element<Element>(is_scalar ? 0 : i);
if constexpr (std::is_convertible_v<Element, CastType>) {
return static_cast<CastType>(element);
} else if constexpr (is_fixed_point<Element>() && cuda::std::is_floating_point_v<CastType>) {
return convert_fixed_to_floating<CastType>(element);
} else if constexpr (is_fixed_point<CastType>() && cuda::std::is_floating_point_v<Element>) {
return convert_floating_to_fixed<CastType>(element, numeric::scale_type{0});
}
}
return {};
}
};
Expand Down Expand Up @@ -159,6 +166,7 @@ struct ops2_wrapper {
TypeRhs y = rhs.element<TypeRhs>(is_rhs_scalar ? 0 : i);
auto result = [&]() {
if constexpr (std::is_same_v<BinaryOperator, ops::NullEquals> or
std::is_same_v<BinaryOperator, ops::NullNotEquals> or
std::is_same_v<BinaryOperator, ops::NullLogicalAnd> or
std::is_same_v<BinaryOperator, ops::NullLogicalOr> or
std::is_same_v<BinaryOperator, ops::NullMax> or
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/binaryop/compiled/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ struct common_type_functor {
template <typename TypeLhs, typename TypeRhs>
std::optional<data_type> operator()() const
{
if constexpr (cudf::has_common_type_v<TypeLhs, TypeRhs>) {
using TypeCommon = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (binary_op_has_common_type_v<TypeLhs, TypeRhs>) {
using TypeCommon = binary_op_common_type_t<TypeLhs, TypeRhs>;
return data_type{type_to_id<TypeCommon>()};
}

Expand Down Expand Up @@ -85,8 +85,8 @@ struct is_binary_operation_supported {
{
if constexpr (column_device_view::has_element_accessor<TypeLhs>() and
column_device_view::has_element_accessor<TypeRhs>()) {
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (binary_op_has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = binary_op_common_type_t<TypeLhs, TypeRhs>;
return std::is_invocable_v<BinaryOperator, common_t, common_t>;
} else {
return std::is_invocable_v<BinaryOperator, TypeLhs, TypeRhs>;
Expand All @@ -102,8 +102,8 @@ struct is_binary_operation_supported {
if constexpr (column_device_view::has_element_accessor<TypeLhs>() and
column_device_view::has_element_accessor<TypeRhs>()) {
if (has_mutable_element_accessor(out_type) or is_fixed_point(out_type)) {
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (binary_op_has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = binary_op_common_type_t<TypeLhs, TypeRhs>;
if constexpr (std::is_invocable_v<BinaryOperator, common_t, common_t>) {
using ReturnType = std::invoke_result_t<BinaryOperator, common_t, common_t>;
return is_constructible<ReturnType>(out_type) or
Expand Down
58 changes: 58 additions & 0 deletions cpp/tests/binaryop/binop-compiled-fixed_point-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,3 +843,61 @@ TYPED_TEST(FixedPointTest_64_128_Reps, FixedPoint_64_128_ComparisonTests)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(h->view(), falses);
}
}

template <typename ResultType>
void test_fixed_floating(cudf::binary_operator op,
double floating_value,
int decimal_value,
int decimal_scale,
ResultType expected)
{
auto const scale = numeric::scale_type{decimal_scale};
auto const result_type = cudf::data_type(cudf::type_to_id<ResultType>());
auto const nullable =
(op == cudf::binary_operator::NULL_EQUALS || op == cudf::binary_operator::NULL_NOT_EQUALS ||
op == cudf::binary_operator::NULL_MIN || op == cudf::binary_operator::NULL_MAX);

cudf::test::fixed_width_column_wrapper<double> floating_col({floating_value});
cudf::test::fixed_point_column_wrapper<int> decimal_col({decimal_value}, scale);

auto result = binary_operation(floating_col, decimal_col, op, result_type);

if constexpr (cudf::is_fixed_point<ResultType>()) {
using wrapper_type = cudf::test::fixed_point_column_wrapper<typename ResultType::rep>;
auto const expected_col = nullable ? wrapper_type({expected.value()}, {true}, expected.scale())
: wrapper_type({expected.value()}, expected.scale());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_col, *result.get());
} else {
using wrapper_type = cudf::test::fixed_width_column_wrapper<ResultType>;
auto const expected_col =
nullable ? wrapper_type({expected}, {true}) : wrapper_type({expected});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_col, *result.get());
}
}

TYPED_TEST(FixedPointCompiledTest, FixedPointWithFloating)
{
using namespace numeric;

// BOOLEAN
test_fixed_floating(cudf::binary_operator::EQUAL, 1.0, 10, -1, true);
test_fixed_floating(cudf::binary_operator::NOT_EQUAL, 1.0, 10, -1, false);
test_fixed_floating(cudf::binary_operator::LESS, 2.0, 10, -1, false);
test_fixed_floating(cudf::binary_operator::GREATER, 2.0, 10, -1, true);
test_fixed_floating(cudf::binary_operator::LESS_EQUAL, 2.0, 20, -1, true);
test_fixed_floating(cudf::binary_operator::GREATER_EQUAL, 2.0, 30, -1, false);
test_fixed_floating(cudf::binary_operator::NULL_EQUALS, 1.0, 10, -1, true);
test_fixed_floating(cudf::binary_operator::NULL_NOT_EQUALS, 1.0, 10, -1, false);

// PRIMARY ARITHMETIC
auto const decimal_result = numeric::decimal32(4, numeric::scale_type{0});
test_fixed_floating(cudf::binary_operator::ADD, 1.0, 30, -1, decimal_result);
test_fixed_floating(cudf::binary_operator::SUB, 6.0, 20, -1, decimal_result);
test_fixed_floating(cudf::binary_operator::MUL, 2.0, 20, -1, decimal_result);
test_fixed_floating(cudf::binary_operator::DIV, 8.0, 2, 0, decimal_result);
test_fixed_floating(cudf::binary_operator::MOD, 9.0, 50, -1, decimal_result);

// OTHER ARITHMETIC
test_fixed_floating(cudf::binary_operator::NULL_MAX, 4.0, 20, -1, decimal_result);
test_fixed_floating(cudf::binary_operator::NULL_MIN, 4.0, 200, -1, decimal_result);
}

0 comments on commit 673d766

Please sign in to comment.