From 200fc0b35216c01235103e491d5217b932670ebc Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Thu, 29 Feb 2024 13:25:35 -0800 Subject: [PATCH] Use cuco::static_set in the hash-based groupby (#14813) Depends on https://github.com/rapidsai/cudf/pull/14849 Contributes to #12261 This PR migrates hash groupby to use the new `cuco::static_set` data structure. It doesn't change any existing libcudf behavior but uncovers the fact that the cudf python `value_counts` doesn't guarantee output orders thus the PR becomes a breaking change. Authors: - Yunsong Wang (https://github.com/PointKernel) Approvers: - David Wendt (https://github.com/davidwendt) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/14813 --- cpp/benchmarks/groupby/group_max.cpp | 7 +- cpp/benchmarks/groupby/group_struct_keys.cpp | 9 +- cpp/include/cudf/detail/cuco_helpers.hpp | 5 + cpp/src/groupby/hash/groupby.cu | 123 ++++++++---------- cpp/src/groupby/hash/groupby_kernels.cuh | 47 +++---- cpp/src/groupby/hash/multi_pass_kernels.cuh | 13 +- .../source/user_guide/pandas-comparison.md | 2 +- python/cudf/cudf/core/dataframe.py | 4 +- python/cudf/cudf/core/groupby/groupby.py | 28 ++-- python/cudf/cudf/tests/test_groupby.py | 16 ++- 10 files changed, 125 insertions(+), 129 deletions(-) diff --git a/cpp/benchmarks/groupby/group_max.cpp b/cpp/benchmarks/groupby/group_max.cpp index e65c37f001d..b7b330f02e5 100644 --- a/cpp/benchmarks/groupby/group_max.cpp +++ b/cpp/benchmarks/groupby/group_max.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ #include +#include #include @@ -50,9 +51,13 @@ void bench_groupby_max(nvbench::state& state, nvbench::type_list) requests[0].values = vals->view(); requests[0].aggregations.push_back(cudf::make_max_aggregation()); + auto const mem_stats_logger = cudf::memory_stats_logger(); state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value())); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { auto const result = gb_obj.aggregate(requests); }); + + state.add_buffer_size( + mem_stats_logger.peak_memory_usage(), "peak_memory_usage", "peak_memory_usage"); } NVBENCH_BENCH_TYPES(bench_groupby_max, diff --git a/cpp/benchmarks/groupby/group_struct_keys.cpp b/cpp/benchmarks/groupby/group_struct_keys.cpp index 44a12c1c30e..cadd9c2d137 100644 --- a/cpp/benchmarks/groupby/group_struct_keys.cpp +++ b/cpp/benchmarks/groupby/group_struct_keys.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ #include +#include #include @@ -80,11 +81,15 @@ void bench_groupby_struct_keys(nvbench::state& state) requests[0].aggregations.push_back(cudf::make_min_aggregation()); // Set up nvbench default stream - auto stream = cudf::get_default_stream(); + auto const mem_stats_logger = cudf::memory_stats_logger(); + auto stream = cudf::get_default_stream(); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { auto const result = gb_obj.aggregate(requests); }); + + state.add_buffer_size( + mem_stats_logger.peak_memory_usage(), "peak_memory_usage", "peak_memory_usage"); } NVBENCH_BENCH(bench_groupby_struct_keys) diff --git a/cpp/include/cudf/detail/cuco_helpers.hpp b/cpp/include/cudf/detail/cuco_helpers.hpp index 506f6475637..dca5a39bece 100644 --- a/cpp/include/cudf/detail/cuco_helpers.hpp +++ b/cpp/include/cudf/detail/cuco_helpers.hpp @@ -16,11 +16,16 @@ #pragma once +#include + #include #include namespace cudf::detail { +/// Sentinel value for `cudf::size_type` +static cudf::size_type constexpr CUDF_SIZE_TYPE_SENTINEL = -1; + /// Default load factor for cuco data structures static double constexpr CUCO_DESIRED_LOAD_FACTOR = 0.5; diff --git a/cpp/src/groupby/hash/groupby.cu b/cpp/src/groupby/hash/groupby.cu index 7b85dd02c10..acc1b087510 100644 --- a/cpp/src/groupby/hash/groupby.cu +++ b/cpp/src/groupby/hash/groupby.cu @@ -22,23 +22,19 @@ #include #include #include -#include #include #include #include #include +#include #include #include #include -#include #include -#include -#include #include #include #include #include -#include #include #include #include @@ -49,12 +45,9 @@ #include -#include -#include -#include +#include #include #include -#include #include #include @@ -66,15 +59,12 @@ namespace detail { namespace hash { namespace { -// TODO: replace it with `cuco::static_map` -// https://github.com/rapidsai/cudf/issues/10401 -template -using map_type = concurrent_unordered_map< - cudf::size_type, - cudf::size_type, +// TODO: similar to `contains_table`, using larger CG size like 2 or 4 for nested +// types and `cg_size = 1`for flat data to improve performance +using probing_scheme_type = cuco::linear_probing< + 1, ///< Number of threads used to handle each input key cudf::experimental::row::hash::device_row_hasher, - ComparatorType>; + cudf::nullate::DYNAMIC>>; /** * @brief List of aggregation operations that can be computed with a hash-based @@ -190,14 +180,14 @@ class groupby_simple_aggregations_collector final } }; -template +template class hash_compound_agg_finalizer final : public cudf::detail::aggregation_finalizer { column_view col; data_type result_type; cudf::detail::result_cache* sparse_results; cudf::detail::result_cache* dense_results; device_span gather_map; - map_type const& map; + SetType set; bitmask_type const* __restrict__ row_bitmask; rmm::cuda_stream_view stream; rmm::mr::device_memory_resource* mr; @@ -209,7 +199,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final cudf::detail::result_cache* sparse_results, cudf::detail::result_cache* dense_results, device_span gather_map, - map_type const& map, + SetType set, bitmask_type const* row_bitmask, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -217,7 +207,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final sparse_results(sparse_results), dense_results(dense_results), gather_map(gather_map), - map(map), + set(set), row_bitmask(row_bitmask), stream(stream), mr(mr) @@ -340,8 +330,8 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final rmm::exec_policy(stream), thrust::make_counting_iterator(0), col.size(), - ::cudf::detail::var_hash_functor>{ - map, row_bitmask, *var_result_view, *values_view, *sum_view, *count_view, agg._ddof}); + ::cudf::detail::var_hash_functor{ + set, row_bitmask, *var_result_view, *values_view, *sum_view, *count_view, agg._ddof}); sparse_results->add_result(col, agg, std::move(var_result)); dense_results->add_result(col, agg, to_dense_agg_result(agg)); } @@ -398,13 +388,13 @@ flatten_single_pass_aggs(host_span requests) * * @see groupby_null_templated() */ -template +template void sparse_to_dense_results(table_view const& keys, host_span requests, cudf::detail::result_cache* sparse_results, cudf::detail::result_cache* dense_results, device_span gather_map, - map_type const& map, + SetType set, bool keys_have_nulls, null_policy include_null_keys, rmm::cuda_stream_view stream, @@ -423,7 +413,7 @@ void sparse_to_dense_results(table_view const& keys, // Given an aggregation, this will get the result from sparse_results and // convert and return dense, compacted result auto finalizer = hash_compound_agg_finalizer( - col, sparse_results, dense_results, gather_map, map, row_bitmask_ptr, stream, mr); + col, sparse_results, dense_results, gather_map, set, row_bitmask_ptr, stream, mr); for (auto&& agg : agg_v) { agg->finalize(finalizer); } @@ -467,11 +457,11 @@ auto create_sparse_results_table(table_view const& flattened_values, * @brief Computes all aggregations from `requests` that require a single pass * over the data and stores the results in `sparse_results` */ -template +template void compute_single_pass_aggs(table_view const& keys, host_span requests, cudf::detail::result_cache* sparse_results, - map_type& map, + SetType set, bool keys_have_nulls, null_policy include_null_keys, rmm::cuda_stream_view stream) @@ -494,16 +484,16 @@ void compute_single_pass_aggs(table_view const& keys, ? cudf::detail::bitmask_and(keys, stream, rmm::mr::get_current_device_resource()).first : rmm::device_buffer{}; - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - keys.num_rows(), - hash::compute_single_pass_aggs_fn>{ - map, - *d_values, - *d_sparse_table, - d_aggs.data(), - static_cast(row_bitmask.data()), - skip_key_rows_with_nulls}); + thrust::for_each_n( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + keys.num_rows(), + hash::compute_single_pass_aggs_fn{set, + *d_values, + *d_sparse_table, + d_aggs.data(), + static_cast(row_bitmask.data()), + skip_key_rows_with_nulls}); // Add results back to sparse_results cache auto sparse_result_cols = sparse_table.release(); for (size_t i = 0; i < aggs.size(); i++) { @@ -517,23 +507,15 @@ void compute_single_pass_aggs(table_view const& keys, * @brief Computes and returns a device vector containing all populated keys in * `map`. */ -template -rmm::device_uvector extract_populated_keys(map_type const& map, +template +rmm::device_uvector extract_populated_keys(SetType const& key_set, size_type num_keys, rmm::cuda_stream_view stream) { rmm::device_uvector populated_keys(num_keys, stream); + auto const keys_end = key_set.retrieve_all(populated_keys.begin(), stream.value()); - auto const get_key = cuda::proclaim_return_type::key_type>( - [] __device__(auto const& element) { return element.first; }); // first = key - auto const key_used = [unused = map.get_unused_key()] __device__(auto key) { - return key != unused; - }; - auto const key_itr = thrust::make_transform_iterator(map.data(), get_key); - auto const end_it = cudf::detail::copy_if_safe( - key_itr, key_itr + map.capacity(), populated_keys.begin(), key_used, stream); - - populated_keys.resize(std::distance(populated_keys.begin(), end_it), stream); + populated_keys.resize(std::distance(populated_keys.begin(), keys_end), stream); return populated_keys; } @@ -580,30 +562,33 @@ std::unique_ptr groupby(table_view const& keys, auto const row_hash = cudf::experimental::row::hash::row_hasher{std::move(preprocessed_keys)}; auto const d_row_hash = row_hash.device_hasher(has_null); - size_type constexpr unused_key{std::numeric_limits::max()}; - size_type constexpr unused_value{std::numeric_limits::max()}; - // Cache of sparse results where the location of aggregate value in each - // column is indexed by the hash map + // column is indexed by the hash set cudf::detail::result_cache sparse_results(requests.size()); auto const comparator_helper = [&](auto const d_key_equal) { - using allocator_type = typename map_type::allocator_type; - - auto const map = map_type::create(compute_hash_table_size(num_keys), - stream, - unused_key, - unused_value, - d_row_hash, - d_key_equal, - allocator_type()); - // Compute all single pass aggs first - compute_single_pass_aggs( - keys, requests, &sparse_results, *map, keys_have_nulls, include_null_keys, stream); + auto const set = cuco::static_set{num_keys, + 0.5, // desired load factor + cuco::empty_key{cudf::detail::CUDF_SIZE_TYPE_SENTINEL}, + d_key_equal, + probing_scheme_type{d_row_hash}, + cuco::thread_scope_device, + cuco::storage<1>{}, + cudf::detail::cuco_allocator{stream}, + stream.value()}; - // Extract the populated indices from the hash map and create a gather map. + // Compute all single pass aggs first + compute_single_pass_aggs(keys, + requests, + &sparse_results, + set.ref(cuco::insert_and_find), + keys_have_nulls, + include_null_keys, + stream); + + // Extract the populated indices from the hash set and create a gather map. // Gathering using this map from sparse results will give dense results. - auto gather_map = extract_populated_keys(*map, keys.num_rows(), stream); + auto gather_map = extract_populated_keys(set, keys.num_rows(), stream); // Compact all results from sparse_results and insert into cache sparse_to_dense_results(keys, @@ -611,7 +596,7 @@ std::unique_ptr
groupby(table_view const& keys, &sparse_results, cache, gather_map, - *map, + set.ref(cuco::find), keys_have_nulls, include_null_keys, stream, diff --git a/cpp/src/groupby/hash/groupby_kernels.cuh b/cpp/src/groupby/hash/groupby_kernels.cuh index 4dfb191480b..9abfe22950a 100644 --- a/cpp/src/groupby/hash/groupby_kernels.cuh +++ b/cpp/src/groupby/hash/groupby_kernels.cuh @@ -30,42 +30,34 @@ namespace groupby { namespace detail { namespace hash { /** - * @brief Compute single-pass aggregations and store results into a sparse - * `output_values` table, and populate `map` with indices of unique keys + * @brief Computes single-pass aggregations and store results into a sparse `output_values` table, + * and populate `set` with indices of unique keys * - * The hash map is built by inserting every row `i` from the `keys` and - * `values` tables as a single (key,value) pair. When the pair is inserted, if - * the key was not already present in the map, then the corresponding value is - * simply copied to the output. If the key was already present in the map, - * then the inserted `values` row is aggregated with the existing row. This - * aggregation is done for every element `j` in the row by applying aggregation - * operation `j` between the new and existing element. + * The hash set is built by inserting every row index `i` from the `keys` and `values` tables. If + * the index was not present in the set, insert they index and then copy it to the output. If the + * key was already present in the set, then the inserted index is aggregated with the existing row. + * This aggregation is done for every element `j` in the row by applying aggregation operation `j` + * between the new and existing element. * * Instead of storing the entire rows from `input_keys` and `input_values` in - * the hashmap, we instead store the row indices. For example, when inserting - * row at index `i` from `input_keys` into the hash map, the value `i` is what - * gets stored for the hash map's "key". It is assumed the `map` was constructed + * the hashset, we instead store the row indices. For example, when inserting + * row at index `i` from `input_keys` into the hash set, the value `i` is what + * gets stored for the hash set's "key". It is assumed the `set` was constructed * with a custom comparator that uses these row indices to check for equality * between key rows. For example, comparing two keys `k0` and `k1` will compare * the two rows `input_keys[k0] ?= input_keys[k1]` * - * Likewise, we store the row indices for the hash maps "values". These indices - * index into the `output_values` table. For a given key `k` (which is an index - * into `input_keys`), the corresponding value `v` indexes into `output_values` - * and stores the result of aggregating rows from `input_values` from rows of - * `input_keys` equivalent to the row at `k`. - * * The exact size of the result is not known a priori, but can be upper bounded * by the number of rows in `input_keys` & `input_values`. Therefore, it is * assumed `output_values` has sufficient storage for an equivalent number of * rows. In this way, after all rows are aggregated, `output_values` will likely * be "sparse", meaning that not all rows contain the result of an aggregation. * - * @tparam Map The type of the hash map + * @tparam SetType The type of the hash set device ref */ -template +template struct compute_single_pass_aggs_fn { - Map map; + SetType set; table_device_view input_values; mutable_table_device_view output_values; aggregation::Kind const* __restrict__ aggs; @@ -75,9 +67,9 @@ struct compute_single_pass_aggs_fn { /** * @brief Construct a new compute_single_pass_aggs_fn functor object * - * @param map Hash map object to insert key,value pairs into. + * @param set_ref Hash set object to insert key,value pairs into. * @param input_values The table whose rows will be aggregated in the values - * of the hash map + * of the hash set * @param output_values Table that stores the results of aggregating rows of * `input_values`. * @param aggs The set of aggregation operations to perform across the @@ -88,13 +80,13 @@ struct compute_single_pass_aggs_fn { * null values should be skipped. It `true`, it is assumed `row_bitmask` is a * bitmask where bit `i` indicates the presence of a null value in row `i`. */ - compute_single_pass_aggs_fn(Map map, + compute_single_pass_aggs_fn(SetType set, table_device_view input_values, mutable_table_device_view output_values, aggregation::Kind const* aggs, bitmask_type const* row_bitmask, bool skip_rows_with_nulls) - : map(map), + : set(set), input_values(input_values), output_values(output_values), aggs(aggs), @@ -106,10 +98,9 @@ struct compute_single_pass_aggs_fn { __device__ void operator()(size_type i) { if (not skip_rows_with_nulls or cudf::bit_is_set(row_bitmask, i)) { - auto result = map.insert(thrust::make_pair(i, i)); + auto const result = set.insert_and_find(i); - cudf::detail::aggregate_row( - output_values, result.first->second, input_values, i, aggs); + cudf::detail::aggregate_row(output_values, *result.first, input_values, i, aggs); } } }; diff --git a/cpp/src/groupby/hash/multi_pass_kernels.cuh b/cpp/src/groupby/hash/multi_pass_kernels.cuh index 4bc73631732..7043eafdc10 100644 --- a/cpp/src/groupby/hash/multi_pass_kernels.cuh +++ b/cpp/src/groupby/hash/multi_pass_kernels.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,23 +31,23 @@ namespace cudf { namespace detail { -template +template struct var_hash_functor { - Map const map; + SetType set; bitmask_type const* __restrict__ row_bitmask; mutable_column_device_view target; column_device_view source; column_device_view sum; column_device_view count; size_type ddof; - var_hash_functor(Map const map, + var_hash_functor(SetType set, bitmask_type const* row_bitmask, mutable_column_device_view target, column_device_view source, column_device_view sum, column_device_view count, size_type ddof) - : map(map), + : set(set), row_bitmask(row_bitmask), target(target), source(source), @@ -96,8 +96,7 @@ struct var_hash_functor { __device__ inline void operator()(size_type source_index) { if (row_bitmask == nullptr or cudf::bit_is_set(row_bitmask, source_index)) { - auto result = map.find(source_index); - auto target_index = result->second; + auto const target_index = *set.find(source_index); auto col = source; auto source_type = source.type(); diff --git a/docs/cudf/source/user_guide/pandas-comparison.md b/docs/cudf/source/user_guide/pandas-comparison.md index 03ce58ea9e3..549d91b771a 100644 --- a/docs/cudf/source/user_guide/pandas-comparison.md +++ b/docs/cudf/source/user_guide/pandas-comparison.md @@ -87,7 +87,7 @@ using `.from_arrow()` or `.from_pandas()`. ## Result ordering -By default, `join` (or `merge`) and `groupby` operations in cuDF +By default, `join` (or `merge`), `value_counts` and `groupby` operations in cuDF do *not* guarantee output ordering. Compare the results obtained from Pandas and cuDF below: diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 9b4a79c6841..a0e1a041342 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -7688,10 +7688,10 @@ def value_counts( dog 4 0 cat 4 0 ant 6 0 - >>> df.value_counts() + >>> df.value_counts().sort_index() num_legs num_wings - 4 0 2 2 2 1 + 4 0 2 6 0 1 Name: count, dtype: int64 """ diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 9612349a607..e4370be304a 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -109,11 +109,11 @@ def _is_row_of(chunk, obj): Parrot 30.0 Parrot 20.0 Name: Max Speed, dtype: float64 ->>> ser.groupby(level=0).mean() +>>> ser.groupby(level=0, sort=True).mean() Falcon 370.0 Parrot 25.0 Name: Max Speed, dtype: float64 ->>> ser.groupby(ser > 100).mean() +>>> ser.groupby(ser > 100, sort=True).mean() Max Speed False 25.0 True 370.0 @@ -133,7 +133,7 @@ def _is_row_of(chunk, obj): 1 Falcon 370.0 2 Parrot 24.0 3 Parrot 26.0 ->>> df.groupby(['Animal']).mean() +>>> df.groupby(['Animal'], sort=True).mean() Max Speed Animal Falcon 375.0 @@ -151,22 +151,22 @@ def _is_row_of(chunk, obj): Wild 350.0 Parrot Captive 30.0 Wild 20.0 ->>> df.groupby(level=0).mean() +>>> df.groupby(level=0, sort=True).mean() Max Speed Animal Falcon 370.0 Parrot 25.0 ->>> df.groupby(level="Type").mean() +>>> df.groupby(level="Type", sort=True).mean() Max Speed Type -Wild 185.0 Captive 210.0 +Wild 185.0 >>> df = cudf.DataFrame({{'A': 'a a b'.split(), ... 'B': [1,2,3], ... 'C': [4,6,5]}}) ->>> g1 = df.groupby('A', group_keys=False) ->>> g2 = df.groupby('A', group_keys=True) +>>> g1 = df.groupby('A', group_keys=False, sort=True) +>>> g2 = df.groupby('A', group_keys=True, sort=True) Notice that ``g1`` have ``g2`` have two groups, ``a`` and ``b``, and only differ in their ``group_keys`` argument. Calling `apply` in various ways, @@ -539,11 +539,11 @@ def agg(self, func): ... 'b': [1, 2, 3], ... 'c': [2, 2, 1] ... }) - >>> a.groupby('a').agg('sum') + >>> a.groupby('a', sort=True).agg('sum') b c a - 2 3 1 1 3 4 + 2 3 1 Specifying a list of aggregations to perform on each column. @@ -553,12 +553,12 @@ def agg(self, func): ... 'b': [1, 2, 3], ... 'c': [2, 2, 1] ... }) - >>> a.groupby('a').agg(['sum', 'min']) + >>> a.groupby('a', sort=True).agg(['sum', 'min']) b c sum min sum min a - 2 3 3 1 1 1 3 1 4 2 + 2 3 3 1 1 Using a dict to specify aggregations to perform per column. @@ -568,12 +568,12 @@ def agg(self, func): ... 'b': [1, 2, 3], ... 'c': [2, 2, 1] ... }) - >>> a.groupby('a').agg({'a': 'max', 'b': ['min', 'mean']}) + >>> a.groupby('a', sort=True).agg({'a': 'max', 'b': ['min', 'mean']}) a b max min mean a - 2 2 3 3.0 1 1 1 1.5 + 2 2 3 3.0 Using lambdas/callables to specify aggregations taking parameters. diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index 63e0cf98b27..f856bbedca2 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -55,12 +55,12 @@ def assert_groupby_results_equal( if isinstance(expect, (pd.DataFrame, cudf.DataFrame)): expect = expect.sort_values(by=by).reset_index(drop=True) else: - expect = expect.sort_values().reset_index(drop=True) + expect = expect.sort_values(by=by).reset_index(drop=True) if isinstance(got, cudf.DataFrame): got = got.sort_values(by=by).reset_index(drop=True) else: - got = got.sort_values().reset_index(drop=True) + got = got.sort_values(by=by).reset_index(drop=True) assert_eq(expect, got, **kwargs) @@ -179,7 +179,7 @@ def test_groupby_agg_min_max_dictlist(nelem): def test_groupby_as_index_single_agg(pdf, gdf, as_index): gdf = gdf.groupby("y", as_index=as_index).agg({"x": "mean"}) pdf = pdf.groupby("y", as_index=as_index).agg({"x": "mean"}) - assert_groupby_results_equal(pdf, gdf) + assert_groupby_results_equal(pdf, gdf, as_index=as_index, by="y") @pytest.mark.parametrize("engine", ["cudf", "jit"]) @@ -190,7 +190,7 @@ def test_groupby_as_index_apply(pdf, gdf, as_index, engine): ) kwargs = {"func": lambda df: df["x"].mean(), "include_groups": False} pdf = pdf.groupby("y", as_index=as_index).apply(**kwargs) - assert_groupby_results_equal(pdf, gdf) + assert_groupby_results_equal(pdf, gdf, as_index=as_index, by="y") @pytest.mark.parametrize("as_index", [True, False]) @@ -3714,7 +3714,13 @@ def test_group_by_value_counts(normalize, sort, ascending, dropna, as_index): # TODO: Remove `check_names=False` once testing against `pandas>=2.0.0` assert_groupby_results_equal( - actual, expected, check_names=False, check_index_type=False + actual, + expected, + check_names=False, + check_index_type=False, + as_index=as_index, + by=["gender", "education"], + sort=sort, )