-
Notifications
You must be signed in to change notification settings - Fork 917
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move rank scan implementations from scan_inclusive.cu to rank_scan.cu (…
…#9351) This change was mainly to improve the compile time for `reductions/scan/scan_inclusive.cu` by refactoring out the rank-scan functions into a separate file `rank.cu`. Although the overall compile time improvement for `scan_inclusive.cu` is only 25%, the source code is better organized with this change. The code function has changed. The detail `inclusive_rank_scan` and `inclusive_dense_rank_scan` declarations were moved from `src/reductions/scan/scan.cuh` to `include/cudf/detail/scan.hpp` and dispatching of the RANK and DENSE_RANK aggregation is done in `scan.cpp` instead of handled by `scan_inclusive.cu` and also `scan_exclusive.cu` (which just throws an exception anyway). Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Jake Hemstad (https://github.com/jrhemstad) - Vukasin Milovanovic (https://github.com/vuule) URL: #9351
- Loading branch information
1 parent
fb18491
commit 122da20
Showing
8 changed files
with
172 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <structs/utilities.hpp> | ||
|
||
#include <cudf/column/column_device_view.cuh> | ||
#include <cudf/column/column_factories.hpp> | ||
#include <cudf/detail/utilities/device_operators.cuh> | ||
#include <cudf/table/row_operators.cuh> | ||
|
||
#include <rmm/cuda_stream_view.hpp> | ||
#include <rmm/exec_policy.hpp> | ||
|
||
#include <thrust/scan.h> | ||
#include <thrust/tabulate.h> | ||
|
||
namespace cudf { | ||
namespace detail { | ||
namespace { | ||
|
||
/** | ||
* @brief generate row ranks or dense ranks using a row comparison then scan the results | ||
* | ||
* @tparam has_nulls if the order_by column has nulls | ||
* @tparam value_resolver flag value resolver with boolean first and row number arguments | ||
* @tparam scan_operator scan function ran on the flag values | ||
* @param order_by input column to generate ranks for | ||
* @param resolver flag value resolver | ||
* @param scan_op scan operation ran on the flag results | ||
* @param stream CUDA stream used for device memory operations and kernel launches | ||
* @param mr Device memory resource used to allocate the returned column's device memory | ||
* @return std::unique_ptr<column> rank values | ||
*/ | ||
template <bool has_nulls, typename value_resolver, typename scan_operator> | ||
std::unique_ptr<column> rank_generator(column_view const& order_by, | ||
value_resolver resolver, | ||
scan_operator scan_op, | ||
rmm::cuda_stream_view stream, | ||
rmm::mr::device_memory_resource* mr) | ||
{ | ||
auto const superimposed = structs::detail::superimpose_parent_nulls(order_by, stream, mr); | ||
table_view const order_table{{std::get<0>(superimposed)}}; | ||
auto const flattener = cudf::structs::detail::flatten_nested_columns( | ||
order_table, {}, {}, structs::detail::column_nullability::MATCH_INCOMING); | ||
auto const d_flat_order = table_device_view::create(std::get<0>(flattener), stream); | ||
row_equality_comparator<has_nulls> comparator(*d_flat_order, *d_flat_order, true); | ||
auto ranks = make_fixed_width_column(data_type{type_to_id<size_type>()}, | ||
order_table.num_rows(), | ||
mask_state::UNALLOCATED, | ||
stream, | ||
mr); | ||
auto mutable_ranks = ranks->mutable_view(); | ||
|
||
thrust::tabulate(rmm::exec_policy(stream), | ||
mutable_ranks.begin<size_type>(), | ||
mutable_ranks.end<size_type>(), | ||
[comparator, resolver] __device__(size_type row_index) { | ||
return resolver(row_index == 0 || !comparator(row_index, row_index - 1), | ||
row_index); | ||
}); | ||
|
||
thrust::inclusive_scan(rmm::exec_policy(stream), | ||
mutable_ranks.begin<size_type>(), | ||
mutable_ranks.end<size_type>(), | ||
mutable_ranks.begin<size_type>(), | ||
scan_op); | ||
return ranks; | ||
} | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<column> inclusive_dense_rank_scan(column_view const& order_by, | ||
rmm::cuda_stream_view stream, | ||
rmm::mr::device_memory_resource* mr) | ||
{ | ||
CUDF_EXPECTS(!cudf::structs::detail::is_or_has_nested_lists(order_by), | ||
"Unsupported list type in dense_rank scan."); | ||
if (has_nested_nulls(table_view{{order_by}})) { | ||
return rank_generator<true>( | ||
order_by, | ||
[] __device__(bool equality, auto row_index) { return equality; }, | ||
DeviceSum{}, | ||
stream, | ||
mr); | ||
} | ||
return rank_generator<false>( | ||
order_by, | ||
[] __device__(bool equality, auto row_index) { return equality; }, | ||
DeviceSum{}, | ||
stream, | ||
mr); | ||
} | ||
|
||
std::unique_ptr<column> inclusive_rank_scan(column_view const& order_by, | ||
rmm::cuda_stream_view stream, | ||
rmm::mr::device_memory_resource* mr) | ||
{ | ||
CUDF_EXPECTS(!cudf::structs::detail::is_or_has_nested_lists(order_by), | ||
"Unsupported list type in rank scan."); | ||
if (has_nested_nulls(table_view{{order_by}})) { | ||
return rank_generator<true>( | ||
order_by, | ||
[] __device__(bool equality, auto row_index) { return equality ? row_index + 1 : 0; }, | ||
DeviceMax{}, | ||
stream, | ||
mr); | ||
} | ||
return rank_generator<false>( | ||
order_by, | ||
[] __device__(bool equality, auto row_index) { return equality ? row_index + 1 : 0; }, | ||
DeviceMax{}, | ||
stream, | ||
mr); | ||
} | ||
|
||
} // namespace detail | ||
} // namespace cudf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters