-
Notifications
You must be signed in to change notification settings - Fork 915
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add segmented reduction support for fixed-point types (#12680)
Depends on #12573 Adds additional support for fixed-point types in `cudf::segmented_reduce` for simple aggregations: sum, product, and sum-of-squares. Reference: #10432 Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Mike Wilson (https://github.com/hyperbolic2346) - Nghia Truong (https://github.com/ttnghia) - Bradley Dice (https://github.com/bdice) URL: #12680
- Loading branch information
1 parent
a96b150
commit f076905
Showing
6 changed files
with
341 additions
and
194 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "counts.hpp"

#include <cudf/detail/null_mask.cuh>

#include <thrust/adjacent_difference.h>
#include <thrust/functional.h>
#include <thrust/transform.h>
|
||
namespace cudf { | ||
namespace reduction { | ||
namespace detail { | ||
|
||
/**
 * @brief Compute the number of elements in each segment.
 *
 * `offsets` holds the segment boundaries: segment `i` spans `[offsets[i], offsets[i + 1])`,
 * so `offsets.size() - 1` counts are produced. When `has_nulls` is true and
 * `null_handling == null_policy::EXCLUDE`, the per-segment count is obtained by counting the
 * set bits of `null_mask` inside each segment; otherwise it is the segment length.
 *
 * @param null_mask Bitmask passed to the set-bit count when nulls are excluded
 * @param has_nulls Whether null-aware counting is needed at all
 * @param offsets Indices to segment boundaries
 * @param null_handling How null entries are processed within each segment
 * @param stream Stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned device vector
 * @return One count per segment
 */
rmm::device_uvector<size_type> segmented_counts(bitmask_type const* null_mask,
                                                bool has_nulls,
                                                device_span<size_type const> offsets,
                                                null_policy null_handling,
                                                rmm::cuda_stream_view stream,
                                                rmm::mr::device_memory_resource* mr)
{
  // `offsets.size()` is unsigned: without this guard an empty span would make
  // `size() - 1` wrap around and `begin() + 1` step past `end()`.
  if (offsets.empty()) { return rmm::device_uvector<size_type>(0, stream, mr); }

  auto const num_segments = offsets.size() - 1;

  if (has_nulls && (null_handling == null_policy::EXCLUDE)) {
    return cudf::detail::segmented_count_bits(null_mask,
                                              offsets.begin(),
                                              offsets.end() - 1,
                                              offsets.begin() + 1,
                                              cudf::detail::count_bits_policy::SET_BITS,
                                              stream,
                                              mr);
  }

  // Segment length is offsets[i + 1] - offsets[i]. thrust::adjacent_difference is not used
  // here because it copies its first input element to the output unchanged, which would
  // report offsets[1] (instead of offsets[1] - offsets[0]) for the first segment whenever
  // the offsets do not start at zero.
  rmm::device_uvector<size_type> valid_counts(num_segments, stream, mr);
  thrust::transform(rmm::exec_policy(stream),
                    offsets.begin() + 1,
                    offsets.end(),
                    offsets.begin(),
                    valid_counts.begin(),
                    thrust::minus<size_type>{});
  return valid_counts;
}
|
||
} // namespace detail | ||
} // namespace reduction | ||
} // namespace cudf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cudf/types.hpp> | ||
#include <cudf/utilities/span.hpp> | ||
|
||
#include <rmm/cuda_stream_view.hpp> | ||
#include <rmm/device_uvector.hpp> | ||
|
||
namespace cudf { | ||
class column_device_view; | ||
|
||
namespace reduction { | ||
namespace detail { | ||
|
||
/** | ||
* @brief Compute the number of elements per segment | ||
* | ||
* If `has_nulls` is true and `null_handling == null_policy::EXCLUDE`, the | ||
* count for each segment omits any null entries. Otherwise, this returns | ||
* the number of elements in each segment. | ||
* | ||
* One count is produced per segment, i.e. `offsets.size() - 1` values. | ||
* | ||
* @param null_mask Validity bitmask of the data over which the segment offsets apply | ||
* @param has_nulls True if the data described by `null_mask` contains any nulls | ||
* @param offsets Indices to segment boundaries | ||
* @param null_handling How null entries are processed within each segment | ||
* @param stream Used for device memory operations and kernel launches | ||
* @param mr Device memory resource used to allocate the returned device vector | ||
* @return The number of elements in each segment | ||
*/ | ||
rmm::device_uvector<size_type> segmented_counts(bitmask_type const* null_mask, | ||
bool has_nulls, | ||
device_span<size_type const> offsets, | ||
null_policy null_handling, | ||
rmm::cuda_stream_view stream, | ||
rmm::mr::device_memory_resource* mr); | ||
|
||
} // namespace detail | ||
} // namespace reduction | ||
} // namespace cudf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.