Skip to content

Commit

Permalink
Add segmented reduction support for fixed-point types (#12680)
Browse files Browse the repository at this point in the history
Depends on #12573 

Adds additional support for fixed-point types in `cudf::segmented_reduce` for simple aggregations: sum, product, and sum-of-squares.
Reference: #10432

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Mike Wilson (https://github.com/hyperbolic2346)
  - Nghia Truong (https://github.com/ttnghia)
  - Bradley Dice (https://github.com/bdice)

URL: #12680
  • Loading branch information
davidwendt authored Feb 23, 2023
1 parent a96b150 commit f076905
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 194 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ add_library(
src/reductions/scan/scan_inclusive.cu
src/reductions/segmented/all.cu
src/reductions/segmented/any.cu
src/reductions/segmented/counts.cu
src/reductions/segmented/max.cu
src/reductions/segmented/mean.cu
src/reductions/segmented/min.cu
Expand Down
29 changes: 11 additions & 18 deletions cpp/src/reductions/segmented/compound.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include "counts.hpp"
#include "update_validity.hpp"

#include <cudf/column/column_factories.hpp>
Expand Down Expand Up @@ -63,34 +64,26 @@ std::unique_ptr<column> compound_segmented_reduction(column_view const& col,
data_type{type_to_id<ResultType>()}, num_segments, mask_state::UNALLOCATED, stream, mr);
auto out_itr = result->mutable_view().template begin<ResultType>();

// Compute valid counts
rmm::device_uvector<size_type> valid_counts(num_segments, stream);
if (col.has_nulls() && (null_handling == null_policy::EXCLUDE)) {
auto valid_fn = [] __device__(auto p) -> size_type { return static_cast<size_type>(p.second); };
auto itr = thrust::make_transform_iterator(d_col->pair_begin<InputType, true>(), valid_fn);
cudf::reduction::detail::segmented_reduce(itr,
offsets.begin(),
offsets.end(),
valid_counts.data(),
thrust::plus<size_type>{},
0,
stream);
} else {
thrust::adjacent_difference(
rmm::exec_policy(stream), offsets.begin() + 1, offsets.end(), valid_counts.begin());
}
// Compute counts
rmm::device_uvector<size_type> counts =
cudf::reduction::detail::segmented_counts(col.null_mask(),
col.has_nulls(),
offsets,
null_handling,
stream,
rmm::mr::get_current_device_resource());

// Run segmented reduction
if (col.has_nulls()) {
auto nrt = compound_op.template get_null_replacing_element_transformer<ResultType>();
auto itr = thrust::make_transform_iterator(d_col->pair_begin<InputType, true>(), nrt);
cudf::reduction::detail::segmented_reduce(
itr, offsets.begin(), offsets.end(), out_itr, compound_op, ddof, valid_counts.data(), stream);
itr, offsets.begin(), offsets.end(), out_itr, compound_op, ddof, counts.data(), stream);
} else {
auto et = compound_op.template get_element_transformer<ResultType>();
auto itr = thrust::make_transform_iterator(d_col->begin<InputType>(), et);
cudf::reduction::detail::segmented_reduce(
itr, offsets.begin(), offsets.end(), out_itr, compound_op, ddof, valid_counts.data(), stream);
itr, offsets.begin(), offsets.end(), out_itr, compound_op, ddof, counts.data(), stream);
}

// Compute the output null mask
Expand Down
54 changes: 54 additions & 0 deletions cpp/src/reductions/segmented/counts.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "counts.hpp"

#include <cudf/detail/null_mask.cuh>

#include <thrust/adjacent_difference.h>
#include <thrust/functional.h>
#include <thrust/transform.h>

namespace cudf {
namespace reduction {
namespace detail {

/**
 * Compute the number of elements in each segment described by `offsets`.
 *
 * `offsets` holds one boundary more than there are segments; segment `i` spans
 * the rows `[offsets[i], offsets[i + 1])`.
 *
 * When `has_nulls` is true and `null_handling` is `null_policy::EXCLUDE`, the
 * count for each segment is the number of set bits of `null_mask` inside that
 * segment. In every other case the count is the segment length,
 * `offsets[i + 1] - offsets[i]`.
 *
 * @param null_mask Bitmask the segment offsets index into; only read when
 *                  `has_nulls` is true and nulls are excluded
 * @param has_nulls Whether the data covered by `null_mask` contains nulls
 * @param offsets Segment boundaries
 * @param null_handling Whether null entries are excluded from the counts
 * @param stream Stream used for the allocation and the device algorithms
 * @param mr Memory resource used to allocate the returned vector
 * @return One count per segment
 */
rmm::device_uvector<size_type> segmented_counts(bitmask_type const* null_mask,
                                                bool has_nulls,
                                                device_span<size_type const> offsets,
                                                null_policy null_handling,
                                                rmm::cuda_stream_view stream,
                                                rmm::mr::device_memory_resource* mr)
{
  // An empty offsets span describes no segments. Handle it up front because
  // `offsets.size() - 1` is not a meaningful segment count for a size of zero.
  if (offsets.size() == 0) { return rmm::device_uvector<size_type>(0, stream, mr); }

  auto const num_segments = offsets.size() - 1;

  if (has_nulls && (null_handling == null_policy::EXCLUDE)) {
    // Count the valid (set) bits between each pair of adjacent offsets.
    return cudf::detail::segmented_count_bits(null_mask,
                                              offsets.begin(),
                                              offsets.end() - 1,
                                              offsets.begin() + 1,
                                              cudf::detail::count_bits_policy::SET_BITS,
                                              stream,
                                              mr);
  }

  // counts[i] = offsets[i + 1] - offsets[i].
  // A binary transform over the two shifted ranges is used instead of
  // thrust::adjacent_difference on [offsets + 1, end): adjacent_difference
  // copies its first input element to the output unchanged, which would yield
  // offsets[1] (not offsets[1] - offsets[0]) for the first segment whenever
  // the first offset is non-zero.
  rmm::device_uvector<size_type> counts(num_segments, stream, mr);
  thrust::transform(rmm::exec_policy(stream),
                    offsets.begin() + 1,
                    offsets.end(),
                    offsets.begin(),
                    counts.begin(),
                    thrust::minus<size_type>{});
  return counts;
}

} // namespace detail
} // namespace reduction
} // namespace cudf
55 changes: 55 additions & 0 deletions cpp/src/reductions/segmented/counts.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

namespace cudf {
class column_device_view;

namespace reduction {
namespace detail {

/**
* @brief Compute the number of elements per segment
*
* If `has_nulls` is true and `null_handling == null_policy::EXCLUDE`, the
* count for each segment omits any null entries. Otherwise, this returns
* the number of elements in each segment.
*
* @param null_mask Validity bitmask of the data over which the segment offsets apply
* @param has_nulls True if the data described by `null_mask` contains any nulls
* @param offsets Indices to segment boundaries
* @param null_handling How null entries are processed within each segment
* @param stream Used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned vector's device memory
* @return The number of elements in each segment, one entry per segment
*/
rmm::device_uvector<size_type> segmented_counts(bitmask_type const* null_mask,
bool has_nulls,
device_span<size_type const> offsets,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);

} // namespace detail
} // namespace reduction
} // namespace cudf
91 changes: 68 additions & 23 deletions cpp/src/reductions/segmented/simple.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include "counts.hpp"
#include "update_validity.hpp"

#include <cudf/detail/aggregation/aggregation.hpp>
Expand All @@ -36,6 +37,7 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/reduce.h>

#include <optional>
#include <type_traits>
Expand Down Expand Up @@ -188,7 +190,7 @@ std::unique_ptr<column> string_segmented_reduction(column_view const& col,
}

/**
* @brief Fixed point segmented reduction for 'min', 'max'.
* @brief Specialization for fixed-point segmented reduction
*
* @tparam InputType the input column data-type
* @tparam Op the operator of cudf::reduction::op::
Expand All @@ -200,11 +202,7 @@ std::unique_ptr<column> string_segmented_reduction(column_view const& col,
* @param mr Device memory resource used to allocate the returned column's device memory
* @return Output column in device memory
*/

template <typename InputType,
typename Op,
CUDF_ENABLE_IF(std::is_same_v<Op, cudf::reduction::op::min> ||
std::is_same_v<Op, cudf::reduction::op::max>)>
template <typename InputType, typename Op>
std::unique_ptr<column> fixed_point_segmented_reduction(
column_view const& col,
device_span<size_type const> offsets,
Expand All @@ -214,23 +212,55 @@ std::unique_ptr<column> fixed_point_segmented_reduction(
rmm::mr::device_memory_resource* mr)
{
using RepType = device_storage_type_t<InputType>;
return simple_segmented_reduction<RepType, RepType, Op>(
col, offsets, null_handling, init, stream, mr);
}
auto result =
simple_segmented_reduction<RepType, RepType, Op>(col, offsets, null_handling, init, stream, mr);
auto const scale = [&] {
if constexpr (std::is_same_v<Op, cudf::reduction::op::product>) {
// The product aggregation requires updating the scale of the fixed-point output column.
// The output scale needs to be the maximum count of all segments multiplied by
// the input scale value.
rmm::device_uvector<size_type> const counts =
cudf::reduction::detail::segmented_counts(col.null_mask(),
col.has_nulls(),
offsets,
null_policy::EXCLUDE, // do not count nulls
stream,
rmm::mr::get_current_device_resource());

auto const max_count = thrust::reduce(rmm::exec_policy(stream),
counts.begin(),
counts.end(),
size_type{0},
thrust::maximum<size_type>{});

auto const new_scale = numeric::scale_type{col.type().scale() * max_count};

// adjust values in each segment to match the new scale
auto const d_col = column_device_view::create(col, stream);
thrust::transform(rmm::exec_policy(stream),
d_col->begin<InputType>(),
d_col->end<InputType>(),
d_col->begin<InputType>(),
[new_scale] __device__(auto fp) { return fp.rescaled(new_scale); });
return new_scale;
}

template <typename InputType,
typename Op,
CUDF_ENABLE_IF(!std::is_same_v<Op, cudf::reduction::op::min>() &&
!std::is_same_v<Op, cudf::reduction::op::max>())>
std::unique_ptr<column> fixed_point_segmented_reduction(
column_view const& col,
device_span<size_type const> offsets,
null_policy null_handling,
std::optional<std::reference_wrapper<scalar const>>,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FAIL("Segmented reduction on fixed point column only supports min and max reduction.");
if constexpr (std::is_same_v<Op, cudf::reduction::op::sum_of_squares>) {
return numeric::scale_type{col.type().scale() * 2};
}

return numeric::scale_type{col.type().scale()};
}();

auto const size = result->size(); // get these before
auto const null_count = result->null_count(); // release() is called
auto contents = result->release();

return std::make_unique<column>(data_type{type_to_id<InputType>(), scale},
size,
std::move(*(contents.data.release())),
std::move(*(contents.null_mask.release())),
null_count);
}

/**
Expand Down Expand Up @@ -431,8 +461,23 @@ struct column_type_dispatcher {
return reduce_numeric<ElementType>(col, offsets, output_type, null_handling, init, stream, mr);
}

template <typename ElementType, std::enable_if_t<cudf::is_fixed_point<ElementType>()>* = nullptr>
std::unique_ptr<column> operator()(column_view const& col,
device_span<size_type const> offsets,
data_type const output_type,
null_policy null_handling,
std::optional<std::reference_wrapper<scalar const>> init,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(output_type == col.type(), "Output type must be same as input column type.");
return fixed_point_segmented_reduction<ElementType, Op>(
col, offsets, null_handling, init, stream, mr);
}

template <typename ElementType,
typename std::enable_if_t<not cudf::is_numeric<ElementType>()>* = nullptr>
std::enable_if_t<not cudf::is_numeric<ElementType>() and
not cudf::is_fixed_point<ElementType>()>* = nullptr>
std::unique_ptr<column> operator()(column_view const&,
device_span<size_type const>,
data_type const,
Expand Down
Loading

0 comments on commit f076905

Please sign in to comment.