Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support null_policy::EXCLUDE for COLLECT rolling aggregation #7264

Merged
merged 10 commits into from
Feb 18, 2021
1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ test:
- test -f $PREFIX/include/cudf_test/cudf_gtest.hpp
- test -f $PREFIX/include/cudf_test/cxxopts.hpp
- test -f $PREFIX/include/cudf_test/file_utilities.hpp
- test -f $PREFIX/include/cudf_test/iterator_utilities.hpp
- test -f $PREFIX/include/cudf_test/table_utilities.hpp
- test -f $PREFIX/include/cudf_test/timestamp_utilities.cuh
- test -f $PREFIX/include/cudf_test/type_list_utilities.hpp
Expand Down
14 changes: 12 additions & 2 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,18 @@ std::unique_ptr<aggregation> make_nth_element_aggregation(
/// Factory to create a ROW_NUMBER aggregation
std::unique_ptr<aggregation> make_row_number_aggregation();

/**
 * @brief Factory to create a COLLECT aggregation
 *
 * `COLLECT` returns a list column of all included elements in the group/series.
 *
 * If `null_handling` is set to `EXCLUDE`, null elements are dropped from each
 * of the list rows.
 *
 * Only this defaulted-argument declaration is kept: a separate zero-argument
 * overload alongside it would make the call `make_collect_aggregation()` ambiguous.
 *
 * @param null_handling Indicates whether to include/exclude nulls in list elements.
 *                      Defaults to `null_policy::INCLUDE`.
 * @return The newly created COLLECT aggregation.
 */
std::unique_ptr<aggregation> make_collect_aggregation(
  null_policy null_handling = null_policy::INCLUDE);

/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset);
Expand Down
24 changes: 24 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ struct quantile_aggregation final : derived_aggregation<quantile_aggregation> {
}
};

/**
* @brief Derived aggregation class for specifying LEAD/LAG window aggregations
*/
struct lead_lag_aggregation final : derived_aggregation<lead_lag_aggregation> {
lead_lag_aggregation(Kind kind, size_type offset)
: derived_aggregation{offset < 0 ? (kind == LAG ? LEAD : LAG) : kind},
Expand Down Expand Up @@ -316,6 +319,27 @@ struct udf_aggregation final : derived_aggregation<udf_aggregation> {
}
};

/**
 * @brief Derived aggregation class for specifying COLLECT aggregation
 *
 * Carries the null-handling policy of the COLLECT request, so that two COLLECT
 * aggregations compare (and hash) differently when their policies differ.
 *
 * The class passes *itself* to `derived_aggregation`, the same way the other
 * aggregation classes in this header do. Naming a different aggregation type
 * there would make the base class and `operator==` treat this object as that
 * other type.
 */
struct collect_list_aggregation final : derived_aggregation<collect_list_aggregation> {
  /**
   * @brief Construct a COLLECT aggregation.
   *
   * @param null_handling Whether null elements are included in or excluded from
   *                      the collected lists. Defaults to `null_policy::INCLUDE`.
   */
  explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE)
    : derived_aggregation{COLLECT}, _null_handling{null_handling}
  {
  }
  null_policy _null_handling;  ///< include or exclude nulls

 protected:
  // The base template needs access to the protected comparison/hash helpers below.
  friend class derived_aggregation<collect_list_aggregation>;

  /// Two COLLECT aggregations are equal when their null policies match.
  bool operator==(collect_list_aggregation const& other) const
  {
    return _null_handling == other._null_handling;
  }

  /// Hash contribution of this class: the null policy, hashed as an int.
  size_t hash_impl() const { return std::hash<int>{}(static_cast<int>(_null_handling)); }
};

/**
* @brief Sentinel value used for `ARGMAX` aggregation.
*
Expand Down
73 changes: 73 additions & 0 deletions cpp/include/cudf_test/iterator_utilities.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/detail/iterator.cuh>
#include <cudf/types.hpp>

#include <thrust/iterator/transform_iterator.h>

#include <algorithm>
#include <vector>

namespace cudf {
namespace test {

/**
* @brief Bool iterator for marking (possibly multiple) null elements in a column_wrapper.
*
* The returned iterator yields `false` (to mark `null`) at all the specified indices,
* and yields `true` (to mark valid rows) for all other indices. E.g.
*
* @code
* auto iter = iterator_with_null_at(std::vector<size_type>{8,9});
* iter[6] == true; // i.e. Valid row at index 6.
* iter[7] == true; // i.e. Valid row at index 7.
* iter[8] == false; // i.e. Invalid row at index 8.
* iter[9] == false; // i.e. Invalid row at index 9.
* @endcode
*
* @param indices The collection of indices for which the validity iterator
* must return `false` (i.e. null)
* @return auto Validity iterator
*/
static auto iterator_with_null_at(std::vector<cudf::size_type> const& indices)
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
{
return cudf::detail::make_counting_transform_iterator(0, [indices](auto i) {
return std::find(indices.begin(), indices.end(), i) == indices.cend();
});
}

/**
 * @brief Bool iterator for marking exactly one null element in a column_wrapper
 *
 * Convenience form of the multi-index overload: the iterator it returns yields
 * `false` (i.e. `null`) only at `index`, and `true` (i.e. valid) everywhere else. E.g.
 *
 * @code
 * auto iter = iterator_with_null_at(8);
 * iter[7] == true;  // i.e. Valid row at index 7.
 * iter[8] == false; // i.e. Invalid row at index 8.
 * @endcode
 *
 * @param index The index for which the validity iterator must return `false` (i.e. null)
 * @return auto Validity iterator
 */
static auto iterator_with_null_at(cudf::size_type const& index)
{
  // Wrap the lone index in a one-element list and defer to the vector overload.
  std::vector<cudf::size_type> const single_null_index{index};
  return iterator_with_null_at(single_null_index);
}

} // namespace test
} // namespace cudf
4 changes: 2 additions & 2 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ std::unique_ptr<aggregation> make_row_number_aggregation()
return std::make_unique<aggregation>(aggregation::ROW_NUMBER);
}
/// Factory to create a COLLECT aggregation
std::unique_ptr<aggregation> make_collect_aggregation()
std::unique_ptr<aggregation> make_collect_aggregation(null_policy null_handling)
{
return std::make_unique<aggregation>(aggregation::COLLECT);
return std::make_unique<detail::collect_list_aggregation>(null_handling);
}
/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset)
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/groupby/sort/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,11 @@ void store_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation cons
template <>
void store_result_functor::operator()<aggregation::COLLECT>(aggregation const& agg)
{
auto null_handling =
static_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
"null exclusion is not supported on groupby COLLECT aggregation.");

if (cache.has_result(col_idx, agg)) return;

auto result = detail::group_collect(
Expand Down
88 changes: 88 additions & 0 deletions cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,81 @@ struct rolling_window_launcher {
return gather_map;
}

/**
 * @brief Count null entries in result of COLLECT.
 *
 * Walks every row index stored in `gather_map` and tallies those whose
 * corresponding row in `input` is null.
 *
 * @param input      Column whose row validity is inspected.
 * @param gather_map Column of `size_type` row indices into `input`.
 * @param stream     CUDA stream used to build the device view and run the count.
 * @return Number of `gather_map` entries that refer to a null row of `input`.
 */
size_type count_child_nulls(column_view const& input,
                            std::unique_ptr<column> const& gather_map,
                            rmm::cuda_stream_view stream)
{
  auto const d_input_owner = column_device_view::create(input, stream);
  auto const gathered_rows = gather_map->view();

  // Predicate evaluated on the device, once per gathered row index.
  auto refers_to_null_row = [d_input = *d_input_owner] __device__(auto row_index) {
    return d_input.is_null_nocheck(row_index);
  };

  auto const first_index = gathered_rows.template begin<size_type>();
  auto const last_index  = gathered_rows.template end<size_type>();

  return thrust::count_if(rmm::exec_policy(stream), first_index, last_index, refers_to_null_row);
}

/**
* @brief Purge entries for null inputs from gather_map, and adjust offsets.
*/
std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
column_view const& input,
std::unique_ptr<column> const& gather_map,
std::unique_ptr<column> const& offsets,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
size_type num_child_nulls,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto input_device_view = column_device_view::create(input, stream);

auto input_row_not_null = [d_input = *input_device_view] __device__(auto i) {
return d_input.is_valid_nocheck(i);
};

// Purge entries in gather_map that correspond to null input.
auto new_gather_map = make_fixed_width_column(data_type{type_to_id<size_type>()},
gather_map->size() - num_child_nulls,
mask_state::UNALLOCATED,
stream,
mr);
thrust::copy_if(rmm::exec_policy(stream),
gather_map->view().template begin<size_type>(),
gather_map->view().template end<size_type>(),
new_gather_map->mutable_view().template begin<size_type>(),
input_row_not_null);

// Recalculate offsets after null entries are purged.
auto new_sizes = make_fixed_width_column(
data_type{type_to_id<size_type>()}, input.size(), mask_state::UNALLOCATED, stream, mr);

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(input.size()),
new_sizes->mutable_view().template begin<size_type>(),
[d_gather_map = gather_map->view().template begin<size_type>(),
d_old_offsets = offsets->view().template begin<size_type>(),
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});
mythrocks marked this conversation as resolved.
Show resolved Hide resolved

auto new_offsets =
strings::detail::make_offsets_child_column(new_sizes->view().template begin<size_type>(),
new_sizes->view().template end<size_type>(),
stream,
mr);

return std::make_pair<std::unique_ptr<column>, std::unique_ptr<column>>(
std::move(new_gather_map), std::move(new_offsets));
}

template <aggregation::Kind op, typename PrecedingIter, typename FollowingIter>
std::enable_if_t<(op == aggregation::COLLECT), std::unique_ptr<column>> operator()(
column_view const& input,
Expand Down Expand Up @@ -1106,6 +1181,19 @@ struct rolling_window_launcher {
auto gather_map = create_collect_gather_map(
offsets->view(), per_row_mapping->view(), preceding_begin, stream, mr);

// If gather_map collects null elements, and null_policy == EXCLUDE,
// those elements must be filtered out, and offsets recomputed.
auto null_handling = static_cast<collect_list_aggregation*>(agg.get())->_null_handling;
if (null_handling == null_policy::EXCLUDE && input.has_nulls()) {
auto num_child_nulls = count_child_nulls(input, gather_map, stream);
if (num_child_nulls != 0) {
auto new_gather_map_and_offsets =
purge_null_entries(input, gather_map, offsets, num_child_nulls, stream, mr);
gather_map = std::move(new_gather_map_and_offsets.first);
offsets = std::move(new_gather_map_and_offsets.second);
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
}
}

// gather(), to construct child column.
auto gather_output =
cudf::gather(table_view{std::vector<column_view>{input}}, gather_map->view());
Expand Down
Loading