Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor reduction logic for fixed-point types #12652

Merged
31 changes: 3 additions & 28 deletions cpp/src/reductions/simple.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -115,37 +115,12 @@ std::unique_ptr<scalar> fixed_point_reduction(
{
using Type = device_storage_type_t<DecimalXX>;

auto dcol = cudf::column_device_view::create(col, stream);
auto simple_op = Op{};

// Cast initial value
std::optional<Type> const initial_value = [&] {
if (init.has_value() && init.value().get().is_valid()) {
using ScalarType = cudf::scalar_type_t<Type>;
return std::optional<Type>(
static_cast<const ScalarType*>(&init.value().get())->value(stream));
} else {
return std::optional<Type>(std::nullopt);
}
}();

auto result = [&] {
if (col.has_nulls()) {
auto f = simple_op.template get_null_replacing_element_transformer<Type>();
auto it = thrust::make_transform_iterator(dcol->pair_begin<Type, true>(), f);
return cudf::reduction::detail::reduce(it, col.size(), simple_op, initial_value, stream, mr);
} else {
auto f = simple_op.template get_element_transformer<Type>();
auto it = thrust::make_transform_iterator(dcol->begin<Type>(), f);
return cudf::reduction::detail::reduce(it, col.size(), simple_op, initial_value, stream, mr);
}
}();
auto result = simple_reduction<Type, Type, Op>(col, init, stream, mr);

auto const scale = [&] {
if (std::is_same_v<Op, cudf::reduction::op::product>) {
auto const valid_count = static_cast<int32_t>(col.size() - col.null_count());
return numeric::scale_type{col.type().scale() *
(valid_count + (initial_value.has_value() ? 1 : 0))};
return numeric::scale_type{col.type().scale() * (valid_count + (init.has_value() ? 1 : 0))};
} else if (std::is_same_v<Op, cudf::reduction::op::sum_of_squares>) {
return numeric::scale_type{col.type().scale() * 2};
}
Expand Down
3 changes: 1 addition & 2 deletions cpp/tests/reductions/reduction_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -294,7 +294,6 @@ TYPED_TEST(SumReductionTest, Sum)
.second);
}

using ReductionTypes = cudf::test::Types<int16_t, int32_t, float, double>;
TYPED_TEST_SUITE(ReductionTest, cudf::test::NumericTypes);

TYPED_TEST(ReductionTest, Product)
Expand Down