Skip to content

Commit

Permalink
Merge pull request rapidsai#6 from thirtiseven/min_by_cudf_sort_only
Browse files Browse the repository at this point in the history
Min by cudf sort only
  • Loading branch information
wjxiz1992 authored Jul 5, 2024
2 parents ab10f5a + 8b8ecda commit af06a0b
Show file tree
Hide file tree
Showing 13 changed files with 288 additions and 2 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ add_library(
src/groupby/sort/group_m2.cu
src/groupby/sort/group_max.cu
src/groupby/sort/group_min.cu
src/groupby/sort/group_min_by.cu
src/groupby/sort/group_merge_lists.cu
src/groupby/sort/group_merge_m2.cu
src/groupby/sort/group_nth_element.cu
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class aggregation {
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation,
MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation,
MIN_BY ///< min reduction by another column
};

aggregation() = delete;
Expand Down Expand Up @@ -381,6 +382,17 @@ std::unique_ptr<Base> make_argmax_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_argmin_aggregation();

/**
 * @brief Factory to create a MIN_BY aggregation
 *
 * `MIN_BY` returns the value of the element in the group that is the minimum
 * according to the order_by column.
 *
 * The sort-based groupby implementation of this aggregation reads the values
 * column as a struct column: child 0 holds the values to return and child 1
 * holds the ordering column. The result is the gathered struct row (value and
 * order) for each group.
 *
 * @tparam Base The aggregation base type of the returned object; defaults to
 * `aggregation`.
 *
 * @return A MIN_BY aggregation object
 */
template <typename Base = aggregation>
std::unique_ptr<Base> make_min_by_aggregation();

/**
* @brief Factory to create a NUNIQUE aggregation
*
Expand Down
29 changes: 29 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class simple_aggregations_collector { // Declares the interface for the simple
class product_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class min_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class min_by_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class max_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
Expand Down Expand Up @@ -217,6 +219,25 @@ class min_aggregation final : public rolling_aggregation,
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived class for specifying a min_by aggregation
*/
class min_by_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
min_by_aggregation() : aggregation(MIN_BY) {}

[[nodiscard]] std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<min_by_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived class for specifying a max aggregation
*/
Expand Down Expand Up @@ -1219,6 +1240,12 @@ struct target_type_impl<Source, aggregation::MIN> {
using type = Source;
};

// Computing MIN_BY of any Source: the result type is always struct_view,
// independent of Source (it is not a Source accumulator).
template <typename Source>
struct target_type_impl<Source, aggregation::MIN_BY> {
using type = struct_view;
};

// Computing MAX of Source, use Source accumulator
template <typename Source>
struct target_type_impl<Source, aggregation::MAX> {
Expand Down Expand Up @@ -1517,6 +1544,8 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
return f.template operator()<aggregation::PRODUCT>(std::forward<Ts>(args)...);
case aggregation::MIN:
return f.template operator()<aggregation::MIN>(std::forward<Ts>(args)...);
case aggregation::MIN_BY:
return f.template operator()<aggregation::MIN_BY>(std::forward<Ts>(args)...);
case aggregation::MAX:
return f.template operator()<aggregation::MAX>(std::forward<Ts>(args)...);
case aggregation::COUNT_VALID:
Expand Down
15 changes: 15 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

// MIN_BY has no dedicated decomposition: forward to the generic
// `aggregation const&` overload of visit().
std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, min_by_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, max_aggregation const& agg)
{
Expand Down Expand Up @@ -637,6 +643,15 @@ template std::unique_ptr<aggregation> make_argmin_aggregation<aggregation>();
template std::unique_ptr<rolling_aggregation> make_argmin_aggregation<rolling_aggregation>();
template std::unique_ptr<groupby_aggregation> make_argmin_aggregation<groupby_aggregation>();

/// Factory to create a MIN_BY aggregation
///
/// Returns a newly allocated detail::min_by_aggregation converted to the requested Base.
template <typename Base>
std::unique_ptr<Base> make_min_by_aggregation()
{
return std::make_unique<detail::min_by_aggregation>();
}
// Explicit instantiations: only the generic `aggregation` and the
// `groupby_aggregation` bases are instantiated here.
template std::unique_ptr<aggregation> make_min_by_aggregation<aggregation>();
template std::unique_ptr<groupby_aggregation> make_min_by_aggregation<groupby_aggregation>();

/// Factory to create an NUNIQUE aggregation
template <typename Base>
std::unique_ptr<Base> make_nunique_aggregation(null_policy null_handling)
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,18 @@ void aggregate_result_functor::operator()<aggregation::MIN>(aggregation const& a
cache.add_result(values, agg, std::move(result));
}

/// Sort-based groupby handler for MIN_BY: computes the per-group result with
/// detail::group_min_by and stores it in the result cache.
template <>
void aggregate_result_functor::operator()<aggregation::MIN_BY>(aggregation const& agg)
{
  // Only compute when no result for (values, agg) has been cached yet.
  if (!cache.has_result(values, agg)) {
    auto min_by_result = detail::group_min_by(
      get_grouped_values(), helper.group_labels(stream), helper.num_groups(stream), stream, mr);
    cache.add_result(values, agg, std::move(min_by_result));
  }
}

template <>
void aggregate_result_functor::operator()<aggregation::MAX>(aggregation const& agg)
{
Expand Down
94 changes: 94 additions & 0 deletions cpp/src/groupby/sort/group_min_by.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "groupby/sort/group_single_pass_reduction_util.cuh"

#include <cudf/detail/gather.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/gather.h>

namespace cudf {
namespace groupby {
namespace detail {
/**
 * @brief Internal API to compute the groupwise MIN_BY of a (value, order) structs column.
 *
 * For each group, the ARGMIN of the order child is computed and the struct row at that index is
 * gathered. In the gathered rows, the value child is null wherever the value or the order of the
 * source row is null.
 *
 * @param structs_column Grouped STRUCT column with at least two children; child 0 holds the
 * values and child 1 the ordering. Only these two children appear in the result.
 * @param group_labels Group label of each row, forwarded to the ARGMIN group reduction
 * @param num_groups Number of groups, forwarded to the ARGMIN group reduction
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned column's device memory
 * @return A STRUCT column with one (value, order) row per ARGMIN result
 *
 * @throws cudf::logic_error if `structs_column` is not a STRUCT column with at least two children
 */
std::unique_ptr<column> group_min_by(column_view const& structs_column,
                                     cudf::device_span<size_type const> group_labels,
                                     size_type num_groups,
                                     rmm::cuda_stream_view stream,
                                     rmm::device_async_resource_ref mr)
{
  // child() is not bounds-checked, so reject anything that does not have the two children we read.
  CUDF_EXPECTS(structs_column.type().id() == type_id::STRUCT && structs_column.num_children() >= 2,
               "MIN_BY requires a STRUCT column whose first two children are (value, order).");

  auto const values = structs_column.child(0);
  auto const orders = structs_column.child(1);

  // Nulls in orders column should be excluded, so we need to create a new bitmask
  // that is the combination of the nulls in both values and orders columns.
  // The mask is a temporary consumed by work enqueued on `stream`, so it must be built on that
  // same stream (not the default stream) to stay stream-ordered with the gather below.
  auto const new_mask_buffer_cnt = cudf::bitmask_and(table_view{{values, orders}}, stream);

  // Re-attach the children of `values` (non-empty only for nested value types).
  std::vector<column_view> struct_children(values.num_children());
  for (size_type i = 0; i < values.num_children(); i++) {
    struct_children[i] = values.child(i);
  }

  // Same data as `values`, but viewed through the combined null mask.
  // NOTE(review): the combined mask is paired with values.offset(); this presumably assumes the
  // grouped values are unsliced (offset 0) -- confirm.
  column_view const values_null_excluded(
    values.type(),
    values.size(),
    values.head(),
    static_cast<bitmask_type const*>(new_mask_buffer_cnt.first.data()),
    new_mask_buffer_cnt.second,
    values.offset(),
    struct_children);

  // Struct view over (masked values, orders). The struct-level null mask is dropped here.
  column_view const structs_column_null_excluded(
    structs_column.type(),
    structs_column.size(),
    structs_column.head(),
    nullptr,
    0,
    structs_column.offset(),
    {values_null_excluded, orders});

  // Row index of the minimum order within each group.
  auto const indices = type_dispatcher(orders.type(),
                                       group_reduction_dispatcher<aggregation::ARGMIN>{},
                                       orders,
                                       num_groups,
                                       group_labels,
                                       stream,
                                       mr);

  // Gather map over the ARGMIN indices with the null mask stripped; when the ARGMIN result is
  // nullable, the gather below nullifies any out-of-bounds index instead.
  column_view const null_removed_map(
    data_type(type_to_id<size_type>()),
    indices->size(),
    static_cast<void const*>(indices->view().data<size_type>()),
    nullptr,
    0);

  auto res = cudf::detail::gather(table_view{{structs_column_null_excluded}},
                                  null_removed_map,
                                  indices->nullable() ? cudf::out_of_bounds_policy::NULLIFY
                                                      : cudf::out_of_bounds_policy::DONT_CHECK,
                                  cudf::detail::negative_index_policy::NOT_ALLOWED,
                                  stream,
                                  mr);

  return std::move(res->release()[0]);
}

} // namespace detail
} // namespace groupby
} // namespace cudf
6 changes: 6 additions & 0 deletions cpp/src/groupby/sort/group_reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ std::unique_ptr<column> group_min(column_view const& values,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);

/**
 * @brief Internal API to calculate groupwise MIN_BY of a (value, order) structs column
 *
 * For each group, the struct row whose order child is the group's ARGMIN is returned.
 *
 * @param structs_column Grouped structs column; child 0 is read as the values and child 1 as
 * the ordering column
 * @param group_labels ID of group that the corresponding value belongs to
 * @param num_groups Number of groups
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned column's device memory
 * @return Structs column holding the selected (value, order) row of each group
 */
std::unique_ptr<column> group_min_by(column_view const& structs_column,
cudf::device_span<size_type const> group_labels,
size_type num_groups,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);

/**
* @brief Internal API to calculate groupwise maximum value
*
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ ConfigureTest(
groupby/keys_tests.cpp
groupby/lists_tests.cpp
groupby/m2_tests.cpp
groupby/min_by_tests.cpp
groupby/min_tests.cpp
groupby/max_scan_tests.cpp
groupby/max_tests.cpp
Expand Down
77 changes: 77 additions & 0 deletions cpp/tests/groupby/min_by_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <tests/groupby/groupby_test_util.hpp>

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/iterator_utilities.hpp>
#include <cudf_test/type_lists.hpp>

#include <cudf/detail/aggregation/aggregation.hpp>

using namespace cudf::test::iterators;

// Typed fixture for MIN_BY groupby tests; V is the element type of the ordering column.
template <typename V>
struct groupby_min_by_test : public cudf::test::BaseFixture {};
// Element type used for the key columns and the value columns in these tests.
using K = int32_t;

// Instantiate the typed tests once per fixed-width type.
TYPED_TEST_SUITE(groupby_min_by_test, cudf::test::FixedWidthTypes);

// MIN_BY over fixed-width ordering columns: the row with the smallest order in each
// key group is expected back as a (value, order) struct.
TYPED_TEST(groupby_min_by_test, basic)
{
  using V = TypeParam;

  // bool orderings are not exercised by this test.
  if (std::is_same_v<V, bool>) { return; }

  // Input: one (value, order) struct per row, grouped by `group_keys`.
  cudf::test::fixed_width_column_wrapper<K> group_keys{1, 2, 3, 1, 2, 2, 1, 3, 3, 2};
  cudf::test::fixed_width_column_wrapper<K> input_values{4, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  cudf::test::fixed_width_column_wrapper<V> input_orders{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  cudf::test::structs_column_wrapper input_structs{input_values, input_orders};

  // Expected: the smallest order per key is 1, 2 and 3, carrying values 4, 1 and 2.
  cudf::test::fixed_width_column_wrapper<K> expected_keys{1, 2, 3};
  cudf::test::fixed_width_column_wrapper<K> expected_values{4, 1, 2};
  cudf::test::fixed_width_column_wrapper<V> expected_orders{1, 2, 3};
  cudf::test::structs_column_wrapper expected_structs{expected_values, expected_orders};

  // Default implementation selection.
  test_single_agg(group_keys,
                  input_structs,
                  expected_keys,
                  expected_structs,
                  cudf::make_min_by_aggregation<cudf::groupby_aggregation>());

  // Sort-based implementation requested explicitly.
  test_single_agg(group_keys,
                  input_structs,
                  expected_keys,
                  expected_structs,
                  cudf::make_min_by_aggregation<cudf::groupby_aggregation>(),
                  force_use_sort_impl::YES);
}

struct groupby_min_by_string_test : public cudf::test::BaseFixture {};

TEST_F(groupby_min_by_string_test, basic)
{
cudf::test::fixed_width_column_wrapper<K> keys{1, 2, 3, 1, 2, 2, 1, 3, 3, 2};
cudf::test::fixed_width_column_wrapper<K> values{4, 1, 2, 3, 4, 5, 6, 7, 8, 9};
cudf::test::strings_column_wrapper orders{
"año", "bit", "₹1", "aaa", "zit", "bat", "aab", "$1", "€1", "wut"};
cudf::test::structs_column_wrapper vals{values, orders};

cudf::test::fixed_width_column_wrapper<K> expect_keys{1, 2, 3};
cudf::test::fixed_width_column_wrapper<K> expect_values{3, 5, 7};
cudf::test::strings_column_wrapper expect_orders{"aaa", "bat", "$1"};
cudf::test::structs_column_wrapper expect_vals{expect_values, expect_orders};

auto agg = cudf::make_min_by_aggregation<cudf::groupby_aggregation>();
test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg));

auto agg2 = cudf::make_min_by_aggregation<cudf::groupby_aggregation>();
test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg2), force_use_sort_impl::YES);
}
13 changes: 12 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ enum Kind {
TDIGEST(31), // This can take a delta argument for accuracy level
MERGE_TDIGEST(32), // This can take a delta argument for accuracy level
HISTOGRAM(33),
MERGE_HISTOGRAM(34);
MERGE_HISTOGRAM(34),
MIN_BY(35);

final int nativeId;

Expand Down Expand Up @@ -675,6 +676,16 @@ static ArgMinAggregation argMin() {
return new ArgMinAggregation();
}

/**
 * Create a MIN_BY aggregation.
 * @return a new MinByAggregation, which carries no parameters.
 */
static MinByAggregation minBy() {
return new MinByAggregation();
}

/**
 * Parameterless aggregation of kind {@code Kind.MIN_BY}. Instances are created
 * through {@code minBy()}; the constructor is private.
 */
static final class MinByAggregation extends NoParamAggregation {
private MinByAggregation() {
super(Kind.MIN_BY);
}
}

static final class NuniqueAggregation extends CountLikeAggregation {
private NuniqueAggregation(NullPolicy nullPolicy) {
super(Kind.NUNIQUE, nullPolicy);
Expand Down
7 changes: 7 additions & 0 deletions java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ public static GroupByAggregation min() {
return new GroupByAggregation(Aggregation.min());
}

/**
 * MinBy Aggregation. Wraps {@code Aggregation.minBy()} for use in a group by.
 * The native sort-based implementation reads the aggregated column as a struct
 * whose first child holds the values and whose second child holds the ordering,
 * and returns the struct row with the minimum ordering for each group.
 */
public static GroupByAggregation minBy() {
return new GroupByAggregation(Aggregation.minBy());
}

/**
* Max Aggregation
*/
Expand Down
2 changes: 2 additions & 0 deletions java/src/main/native/src/AggregationJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv*
return cudf::make_histogram_aggregation();
case 34: // MERGE_HISTOGRAM
return cudf::make_merge_histogram_aggregation();
case 35: // MINBY
return cudf::make_min_by_aggregation();

default: throw std::logic_error("Unsupported No Parameter Aggregation Operation");
}
Expand Down
Loading

0 comments on commit af06a0b

Please sign in to comment.