Skip to content

Commit

Permalink
Add groupby::replace_nulls(replace_policy) api (#7118)
Browse files Browse the repository at this point in the history
Part 1 of #4896, follow up of #6907 

This PR provides a groupby version of the `replace_nulls(replace_policy)` function. A regular `replace_nulls(replace_policy)` operation updates the nulls with the first non-null value that precedes/follows the null. The groupby version is similar, with an exception that the non-null value look-up is bounded by groups.

Here is an example to illustrate the API input/output behavior:

```python
#Input:
keys = [2, 1, 2, 1]
values = [3, 4, NULL, NULL]

#Output, group order is not guaranteed:
sorted_keys = [1, 1, 2, 2]
result = [4, 4, 3, 3]
```

Authors:
  - Michael Wang (https://github.com/isVoid)

Approvers:
  - AJ Schmidt (https://github.com/ajschmidt8)
  - https://github.com/nvdbaranec
  - https://github.com/brandon-b-miller
  - Jake Hemstad (https://github.com/jrhemstad)

URL: #7118
  • Loading branch information
isVoid authored May 24, 2021
1 parent c398054 commit 6dbf2d5
Show file tree
Hide file tree
Showing 12 changed files with 660 additions and 18 deletions.
1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ test:
- test -f $PREFIX/include/cudf/detail/gather.hpp
- test -f $PREFIX/include/cudf/detail/groupby.hpp
- test -f $PREFIX/include/cudf/detail/groupby/sort_helper.hpp
- test -f $PREFIX/include/cudf/detail/groupby/group_replace_nulls.hpp
- test -f $PREFIX/include/cudf/detail/hashing.hpp
- test -f $PREFIX/include/cudf/detail/interop.hpp
- test -f $PREFIX/include/cudf/detail/is_element_valid.hpp
Expand Down
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ add_library(cudf
src/groupby/sort/group_max_scan.cu
src/groupby/sort/group_min_scan.cu
src/groupby/sort/group_sum_scan.cu
src/groupby/sort/group_replace_nulls.cu
src/groupby/sort/sort_helper.cu
src/hash/hashing.cu
src/hash/md5_hash.cu
Expand Down
47 changes: 47 additions & 0 deletions cpp/include/cudf/detail/groupby/group_replace_nulls.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/column/column_view.hpp>
#include <cudf/replace.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/exec_policy.hpp>
namespace cudf {
namespace groupby {
namespace detail {

/**
* @brief Internal API to replace nulls with preceding/following non-null values in @p value
*
* @param[in] grouped_value A column whose null values will be replaced.
* @param[in] group_labels Group labels for @p grouped_value, corresponding to group keys.
* @param[in] replace_policy Specify the position of replacement values relative to null values.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param[in] mr Device memory resource used to allocate device memory of the returned column.
*/
std::unique_ptr<column> group_replace_nulls(
cudf::column_view const& grouped_value,
device_span<size_type const> group_labels,
cudf::replace_policy replace_policy,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace detail
} // namespace groupby
} // namespace cudf
44 changes: 44 additions & 0 deletions cpp/include/cudf/detail/replace/nulls.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/types.hpp>

#include <thrust/functional.h>

namespace cudf {
namespace detail {

using idx_valid_pair_t = thrust::tuple<cudf::size_type, bool>;

/**
* @brief Functor used by `replace_nulls(replace_policy)` to determine the index to gather from in
* the result column.
*
* Binary functor passed to `inclusive_scan` or `inclusive_scan_by_key`. Arguments are a tuple of
* index and validity of a row. Returns a tuple of current index and a discarded boolean if current
* row is valid, otherwise a tuple of the nearest non-null row index and a discarded boolean.
*/
struct replace_policy_functor {
__device__ idx_valid_pair_t operator()(idx_valid_pair_t const& lhs, idx_valid_pair_t const& rhs)
{
return thrust::get<1>(rhs) ? thrust::make_tuple(thrust::get<0>(rhs), true)
: thrust::make_tuple(thrust::get<0>(lhs), true);
}
};

} // namespace detail
} // namespace cudf
43 changes: 43 additions & 0 deletions cpp/include/cudf/groupby.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
#pragma once

#include <cudf/aggregation.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/replace.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <memory>
#include <rmm/cuda_stream_view.hpp>

#include <utility>
Expand Down Expand Up @@ -287,6 +290,46 @@ class groupby {
groups get_groups(cudf::table_view values = {},
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Performs grouped replace nulls on @p value
*
* For each `value[i] == NULL` in group `j`, `value[i]` is replaced with the first non-null value
* in group `j` that precedes or follows `value[i]`. If a non-null value is not found in the
* specified direction, `value[i]` is left NULL.
*
* The returned pair contains a column of the sorted keys and the result column. In result column,
* values of the same group are in contiguous memory. In each group, the order of values maintain
* their original order. The order of groups are not guaranteed.
*
* Example:
* @code{.pseudo}
*
* //Inputs:
* keys: {3 3 1 3 1 3 4}
* {2 2 1 2 1 2 5}
* values: {3 4 7 @ @ @ @}
* {@ @ @ "x" "tt" @ @}
* replace_policies: {FORWARD, BACKWARD}
*
* //Outputs (group orders may be different):
* keys: {3 3 3 3 1 1 4}
* {2 2 2 2 1 1 5}
* result: {3 4 4 4 7 7 @}
* {"x" "x" "x" @ "tt" "tt" @}
* @endcode
*
* @param[in] values A table whose column null values will be replaced.
* @param[in] replace_policies Specify the position of replacement values relative to null values,
* one for each column
* @param[in] mr Device memory resource used to allocate device memory of the returned column.
*
* @return Pair that contains a table with the sorted keys and the result column
*/
std::pair<std::unique_ptr<table>, std::unique_ptr<table>> replace_nulls(
table_view const& values,
host_span<cudf::replace_policy const> replace_policies,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

private:
table_view _keys; ///< Keys that determine grouping
null_policy _include_null_keys{null_policy::EXCLUDE}; ///< Include rows in keys
Expand Down
30 changes: 30 additions & 0 deletions cpp/src/groupby/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cudf/detail/copy.hpp>
#include <cudf/detail/gather.hpp>
#include <cudf/detail/groupby.hpp>
#include <cudf/detail/groupby/group_replace_nulls.hpp>
#include <cudf/detail/groupby/sort_helper.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/dictionary/dictionary_column_view.hpp>
Expand Down Expand Up @@ -223,6 +224,35 @@ groupby::groups groupby::get_groups(table_view values, rmm::mr::device_memory_re
}
}

std::pair<std::unique_ptr<table>, std::unique_ptr<table>> groupby::replace_nulls(
table_view const& values,
host_span<cudf::replace_policy const> replace_policies,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
CUDF_EXPECTS(_keys.num_rows() == values.num_rows(),
"Size mismatch between group labels and value.");
CUDF_EXPECTS(static_cast<cudf::size_type>(replace_policies.size()) == values.num_columns(),
"Size mismatch between num_columns and replace_policies.");

if (values.is_empty()) { return std::make_pair(empty_like(_keys), empty_like(values)); }
auto const stream = rmm::cuda_stream_default;

auto const& group_labels = helper().group_labels(stream);
std::vector<std::unique_ptr<column>> results;
std::transform(thrust::make_counting_iterator(0),
thrust::make_counting_iterator(values.num_columns()),
std::back_inserter(results),
[&](auto i) {
auto grouped_values = helper().grouped_values(values.column(i), stream);
return detail::group_replace_nulls(
grouped_values->view(), group_labels, replace_policies[i], stream, mr);
});

return std::make_pair(std::move(helper().sorted_keys(stream, mr)),
std::make_unique<table>(std::move(results)));
}

// Get the sort helper object
detail::sort::sort_groupby_helper& groupby::helper()
{
Expand Down
82 changes: 82 additions & 0 deletions cpp/src/groupby/sort/group_replace_nulls.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cudf/column/column_device_view.cuh>
#include <cudf/detail/gather.cuh>
#include <cudf/detail/groupby/group_replace_nulls.hpp>
#include <cudf/detail/replace/nulls.cuh>
#include <cudf/replace.hpp>

#include <rmm/device_uvector.hpp>

#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/reverse_iterator.h>

#include <utility>

namespace cudf {
namespace groupby {
namespace detail {

std::unique_ptr<column> group_replace_nulls(cudf::column_view const& grouped_value,
device_span<size_type const> group_labels,
cudf::replace_policy replace_policy,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
cudf::size_type size = grouped_value.size();

auto device_in = cudf::column_device_view::create(grouped_value);
auto index = thrust::make_counting_iterator<cudf::size_type>(0);
auto valid_it = cudf::detail::make_validity_iterator(*device_in);
auto in_begin = thrust::make_zip_iterator(thrust::make_tuple(index, valid_it));

rmm::device_uvector<cudf::size_type> gather_map(size, stream);
auto gm_begin = thrust::make_zip_iterator(
thrust::make_tuple(gather_map.begin(), thrust::make_discard_iterator()));

auto func = cudf::detail::replace_policy_functor();
thrust::equal_to<cudf::size_type> eq;
if (replace_policy == cudf::replace_policy::PRECEDING) {
thrust::inclusive_scan_by_key(rmm::exec_policy(stream),
group_labels.begin(),
group_labels.begin() + size,
in_begin,
gm_begin,
eq,
func);
} else {
auto gl_rbegin = thrust::make_reverse_iterator(group_labels.begin() + size);
auto in_rbegin = thrust::make_reverse_iterator(in_begin + size);
auto gm_rbegin = thrust::make_reverse_iterator(gm_begin + size);
thrust::inclusive_scan_by_key(
rmm::exec_policy(stream), gl_rbegin, gl_rbegin + size, in_rbegin, gm_rbegin, eq, func);
}

auto output = cudf::detail::gather(cudf::table_view({grouped_value}),
gather_map.begin(),
gather_map.end(),
cudf::out_of_bounds_policy::DONT_CHECK,
stream,
mr);

return std::move(output->release()[0]);
}

} // namespace detail
} // namespace groupby
} // namespace cudf
23 changes: 6 additions & 17 deletions cpp/src/replace/nulls.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/replace.hpp>
#include <cudf/detail/replace/nulls.cuh>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/dictionary/detail/replace.hpp>
#include <cudf/dictionary/dictionary_column_view.hpp>
Expand All @@ -40,10 +41,13 @@
#include <cudf/utilities/type_dispatcher.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

#include <thrust/functional.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/reverse_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/scan.h>
#include <thrust/transform.h>

namespace { // anonymous
Expand Down Expand Up @@ -356,22 +360,6 @@ std::unique_ptr<cudf::column> replace_nulls_scalar_kernel_forwarder::operator()<
return cudf::dictionary::detail::replace_nulls(dict_input, replacement, stream, mr);
}

/**
* @brief Functor used by `inclusive_scan` to determine the index to gather from in
* the result column. When current row in input column is NULL, return previous
* accumulated index, otherwise return the current index. The second element in
* the return tuple is discarded.
*/
struct replace_policy_functor {
__device__ thrust::tuple<cudf::size_type, bool> operator()(
thrust::tuple<cudf::size_type, bool> const& lhs,
thrust::tuple<cudf::size_type, bool> const& rhs)
{
return thrust::get<1>(rhs) ? thrust::make_tuple(thrust::get<0>(rhs), true)
: thrust::make_tuple(thrust::get<0>(lhs), true);
}
};

/**
* @brief Function used by replace_nulls policy
*/
Expand All @@ -390,7 +378,7 @@ std::unique_ptr<cudf::column> replace_nulls_policy_impl(cudf::column_view const&
auto gm_begin = thrust::make_zip_iterator(
thrust::make_tuple(gather_map.begin(), thrust::make_discard_iterator()));

auto func = replace_policy_functor();
auto func = cudf::detail::replace_policy_functor();
if (replace_policy == cudf::replace_policy::PRECEDING) {
thrust::inclusive_scan(
rmm::exec_policy(stream), in_begin, in_begin + input.size(), gm_begin, func);
Expand All @@ -414,6 +402,7 @@ std::unique_ptr<cudf::column> replace_nulls_policy_impl(cudf::column_view const&

namespace cudf {
namespace detail {

std::unique_ptr<cudf::column> replace_nulls(cudf::column_view const& input,
cudf::column_view const& replacement,
rmm::cuda_stream_view stream,
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ ConfigureTest(GROUPBY_TEST
groupby/nunique_tests.cpp
groupby/product_tests.cpp
groupby/quantile_tests.cpp
groupby/replace_nulls_tests.cpp
groupby/shift_tests.cpp
groupby/std_tests.cpp
groupby/sum_of_squares_tests.cpp
Expand Down
Loading

0 comments on commit 6dbf2d5

Please sign in to comment.