-
Notifications
You must be signed in to change notification settings - Fork 916
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
groupby::replace_nulls(replace_policy)
api (#7118)
Part 1 of #4896, follow up of #6907 This PR provides a groupby version of the `replace_nulls(replace_policy)` function. A regular `replace_nulls(replace_policy)` operation updates the nulls with the first non-null value that precedes/follows the null. The groupby version is similar, with an exception that the non-null value look-up is bounded by groups. Here is an example to illustrate the API input/output behavior: ```python #Input: keys = [2, 1, 2, 1] values = [3, 4, NULL, NULL] #Output, group order is not guaranteed: sorted_keys = [1, 1, 2, 2] result = [4, 4, 3, 3] ``` Authors: - Michael Wang (https://github.com/isVoid) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) - https://github.com/nvdbaranec - https://github.com/brandon-b-miller - Jake Hemstad (https://github.com/jrhemstad) URL: #7118
- Loading branch information
Showing
12 changed files
with
660 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cudf/column/column_view.hpp> | ||
#include <cudf/replace.hpp> | ||
#include <cudf/types.hpp> | ||
#include <cudf/utilities/span.hpp> | ||
|
||
#include <rmm/exec_policy.hpp> | ||
namespace cudf { | ||
namespace groupby { | ||
namespace detail { | ||
|
||
/** | ||
* @brief Internal API to replace nulls with preceding/following non-null values in @p value | ||
* | ||
* @param[in] grouped_value A column whose null values will be replaced. | ||
* @param[in] group_labels Group labels for @p grouped_value, corresponding to group keys. | ||
* @param[in] replace_policy Specify the position of replacement values relative to null values. | ||
* @param stream CUDA stream used for device memory operations and kernel launches. | ||
* @param[in] mr Device memory resource used to allocate device memory of the returned column. | ||
*/ | ||
std::unique_ptr<column> group_replace_nulls( | ||
cudf::column_view const& grouped_value, | ||
device_span<size_type const> group_labels, | ||
cudf::replace_policy replace_policy, | ||
rmm::cuda_stream_view stream = rmm::cuda_stream_default, | ||
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); | ||
|
||
} // namespace detail | ||
} // namespace groupby | ||
} // namespace cudf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <cudf/types.hpp> | ||
|
||
#include <thrust/functional.h> | ||
|
||
namespace cudf { | ||
namespace detail { | ||
|
||
using idx_valid_pair_t = thrust::tuple<cudf::size_type, bool>; | ||
|
||
/** | ||
* @brief Functor used by `replace_nulls(replace_policy)` to determine the index to gather from in | ||
* the result column. | ||
* | ||
* Binary functor passed to `inclusive_scan` or `inclusive_scan_by_key`. Arguments are a tuple of | ||
* index and validity of a row. Returns a tuple of current index and a discarded boolean if current | ||
* row is valid, otherwise a tuple of the nearest non-null row index and a discarded boolean. | ||
*/ | ||
struct replace_policy_functor { | ||
__device__ idx_valid_pair_t operator()(idx_valid_pair_t const& lhs, idx_valid_pair_t const& rhs) | ||
{ | ||
return thrust::get<1>(rhs) ? thrust::make_tuple(thrust::get<0>(rhs), true) | ||
: thrust::make_tuple(thrust::get<0>(lhs), true); | ||
} | ||
}; | ||
|
||
} // namespace detail | ||
} // namespace cudf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <cudf/column/column_device_view.cuh> | ||
#include <cudf/detail/gather.cuh> | ||
#include <cudf/detail/groupby/group_replace_nulls.hpp> | ||
#include <cudf/detail/replace/nulls.cuh> | ||
#include <cudf/replace.hpp> | ||
|
||
#include <rmm/device_uvector.hpp> | ||
|
||
#include <thrust/functional.h> | ||
#include <thrust/iterator/counting_iterator.h> | ||
#include <thrust/iterator/discard_iterator.h> | ||
#include <thrust/iterator/reverse_iterator.h> | ||
|
||
#include <utility> | ||
|
||
namespace cudf { | ||
namespace groupby { | ||
namespace detail { | ||
|
||
std::unique_ptr<column> group_replace_nulls(cudf::column_view const& grouped_value, | ||
device_span<size_type const> group_labels, | ||
cudf::replace_policy replace_policy, | ||
rmm::cuda_stream_view stream, | ||
rmm::mr::device_memory_resource* mr) | ||
{ | ||
cudf::size_type size = grouped_value.size(); | ||
|
||
auto device_in = cudf::column_device_view::create(grouped_value); | ||
auto index = thrust::make_counting_iterator<cudf::size_type>(0); | ||
auto valid_it = cudf::detail::make_validity_iterator(*device_in); | ||
auto in_begin = thrust::make_zip_iterator(thrust::make_tuple(index, valid_it)); | ||
|
||
rmm::device_uvector<cudf::size_type> gather_map(size, stream); | ||
auto gm_begin = thrust::make_zip_iterator( | ||
thrust::make_tuple(gather_map.begin(), thrust::make_discard_iterator())); | ||
|
||
auto func = cudf::detail::replace_policy_functor(); | ||
thrust::equal_to<cudf::size_type> eq; | ||
if (replace_policy == cudf::replace_policy::PRECEDING) { | ||
thrust::inclusive_scan_by_key(rmm::exec_policy(stream), | ||
group_labels.begin(), | ||
group_labels.begin() + size, | ||
in_begin, | ||
gm_begin, | ||
eq, | ||
func); | ||
} else { | ||
auto gl_rbegin = thrust::make_reverse_iterator(group_labels.begin() + size); | ||
auto in_rbegin = thrust::make_reverse_iterator(in_begin + size); | ||
auto gm_rbegin = thrust::make_reverse_iterator(gm_begin + size); | ||
thrust::inclusive_scan_by_key( | ||
rmm::exec_policy(stream), gl_rbegin, gl_rbegin + size, in_rbegin, gm_rbegin, eq, func); | ||
} | ||
|
||
auto output = cudf::detail::gather(cudf::table_view({grouped_value}), | ||
gather_map.begin(), | ||
gather_map.end(), | ||
cudf::out_of_bounds_policy::DONT_CHECK, | ||
stream, | ||
mr); | ||
|
||
return std::move(output->release()[0]); | ||
} | ||
|
||
} // namespace detail | ||
} // namespace groupby | ||
} // namespace cudf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.