diff --git a/cpp/src/unary/cast_ops.cu b/cpp/src/unary/cast_ops.cu index e852b00796a..131fde11cf8 100644 --- a/cpp/src/unary/cast_ops.cu +++ b/cpp/src/unary/cast_ops.cu @@ -305,28 +305,39 @@ struct dispatch_unary_cast_to { rmm::mr::device_memory_resource* mr) { using namespace numeric; - - auto const size = input.size(); - auto temporary = - std::make_unique(cudf::data_type{type.id(), input.type().scale()}, - size, - rmm::device_buffer{size * cudf::size_of(type), stream}, - copy_bitmask(input, stream), - input.null_count()); - using SourceDeviceT = device_storage_type_t; using TargetDeviceT = device_storage_type_t; - mutable_column_view output_mutable = *temporary; - - thrust::transform(rmm::exec_policy(stream), - input.begin(), - input.end(), - output_mutable.begin(), - device_cast{}); - - // clearly there is a more efficient way to do this, can optimize in the future - return rescale(*temporary, numeric::scale_type{type.scale()}, stream, mr); + auto casted = [&]() { + auto const size = input.size(); + auto output = std::make_unique(cudf::data_type{type.id(), input.type().scale()}, + size, + rmm::device_buffer{size * cudf::size_of(type), stream}, + copy_bitmask(input, stream), + input.null_count()); + + mutable_column_view output_mutable = *output; + + thrust::transform(rmm::exec_policy(stream), + input.begin(), + input.end(), + output_mutable.begin(), + device_cast{}); + + return output; + }; + + if (input.type().scale() == type.scale()) return casted(); + + if constexpr (sizeof(SourceDeviceT) < sizeof(TargetDeviceT)) { + // device_cast BEFORE rescale when SourceDeviceT is < TargetDeviceT + auto temporary = casted(); + return detail::rescale(*temporary, scale_type{type.scale()}, stream, mr); + } else { + // device_cast AFTER rescale when SourceDeviceT is > TargetDeviceT to avoid overflow + auto temporary = detail::rescale(input, scale_type{type.scale()}, stream, mr); + return detail::cast(*temporary, type, stream, mr); + } } template view()); } + +TEST_F(FixedPointTestSingleType, Int32ToInt64Convert) +{ + using namespace numeric; + using fp_wrapperA = cudf::test::fixed_point_column_wrapper; + using fp_wrapperB = cudf::test::fixed_point_column_wrapper; + + auto const input = fp_wrapperB{{141230900000L}, scale_type{-10}}; + auto const expected = fp_wrapperA{{14123}, scale_type{-3}}; + auto const result = cudf::cast(input, make_fixed_point_data_type(-3)); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view()); +}