diff --git a/cpp/include/cudf/binaryop.hpp b/cpp/include/cudf/binaryop.hpp index 72abefef04f..71ac591b67e 100644 --- a/cpp/include/cudf/binaryop.hpp +++ b/cpp/include/cudf/binaryop.hpp @@ -146,7 +146,7 @@ std::unique_ptr binary_operation( column_view const& lhs, column_view const& rhs, binary_operator op, - data_type output_type, + thrust::optional output_type, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** diff --git a/cpp/include/cudf/detail/binaryop.hpp b/cpp/include/cudf/detail/binaryop.hpp index c12482967e1..147e5b71381 100644 --- a/cpp/include/cudf/detail/binaryop.hpp +++ b/cpp/include/cudf/detail/binaryop.hpp @@ -60,7 +60,7 @@ std::unique_ptr binary_operation( column_view const& lhs, column_view const& rhs, binary_operator op, - data_type output_type, + thrust::optional output_type, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/src/binaryop/binaryop.cpp b/cpp/src/binaryop/binaryop.cpp index fc697267ca7..c6e8f506924 100644 --- a/cpp/src/binaryop/binaryop.cpp +++ b/cpp/src/binaryop/binaryop.cpp @@ -783,27 +783,36 @@ std::unique_ptr binary_operation(column_view const& lhs, std::unique_ptr binary_operation(column_view const& lhs, column_view const& rhs, binary_operator op, - data_type output_type, + thrust::optional output_type, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { CUDF_EXPECTS(lhs.size() == rhs.size(), "Column sizes don't match"); - if (lhs.type().id() == type_id::STRING and rhs.type().id() == type_id::STRING) - return binops::compiled::binary_operation(lhs, rhs, op, output_type, stream, mr); - if (is_fixed_point(lhs.type()) or is_fixed_point(rhs.type())) { - auto const type = - op == binary_operator::TRUE_DIV ? output_type : thrust::optional{thrust::nullopt}; - return fixed_point_binary_operation(lhs, rhs, op, type, stream, mr); + if (op != binary_operator::TRUE_DIV) { + CUDF_EXPECTS( + not output_type.has_value(), + "Only TRUE_DIV supports specified output_type for fixed_point binary operations. For other " + "fixed_point binary operations, please pass {} or std::nullopt for output_type and the " + "cudf::data_type and numeric::scale_type will be automatically calculated."); + } + + return fixed_point_binary_operation(lhs, rhs, op, output_type, stream, mr); } + CUDF_EXPECTS(output_type.has_value(), "Must specify output_type of column."); + // Use output_type.value() for the rest of the function + + if (lhs.type().id() == type_id::STRING and rhs.type().id() == type_id::STRING) + return binops::compiled::binary_operation(lhs, rhs, op, output_type.value(), stream, mr); + // Check for datatype - CUDF_EXPECTS(is_fixed_width(output_type), "Invalid/Unsupported output datatype"); + CUDF_EXPECTS(is_fixed_width(type), "Invalid/Unsupported output datatype"); CUDF_EXPECTS(is_fixed_width(lhs.type()), "Invalid/Unsupported lhs datatype"); CUDF_EXPECTS(is_fixed_width(rhs.type()), "Invalid/Unsupported rhs datatype"); - auto out = make_fixed_width_column_for_output(lhs, rhs, op, output_type, stream, mr); + auto out = make_fixed_width_column_for_output(lhs, rhs, op, output_type.value(), stream, mr); if (lhs.is_empty() or rhs.is_empty()) return out; @@ -868,7 +877,7 @@ std::unique_ptr binary_operation(column_view const& lhs, std::unique_ptr binary_operation(column_view const& lhs, column_view const& rhs, binary_operator op, - data_type output_type, + thrust::optional output_type, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); diff --git a/cpp/tests/binaryop/binop-integration-test.cpp b/cpp/tests/binaryop/binop-integration-test.cpp index 2d17853a72b..b39f586b8ea 100644 --- a/cpp/tests/binaryop/binop-integration-test.cpp +++ b/cpp/tests/binaryop/binop-integration-test.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -2123,6 +2124,9 @@ TYPED_TEST(FixedPointTestBothReps, FixedPointBinaryOpDiv2) auto const result = cudf::binary_operation(lhs, rhs, cudf::binary_operator::DIV, {}); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view()); + auto output_type = cudf::data_type{type_to_id(), scale_type{1}}; + EXPECT_THROW(cudf::binary_operation(lhs, rhs, cudf::binary_operator::DIV, output_type), + cudf::logic_error); } TYPED_TEST(FixedPointTestBothReps, FixedPointBinaryOpDiv3)