Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

perf: avoid merge in pandas groupby #1638

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions narwhals/_pandas_like/group_by.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,6 +11,7 @@

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals.utils import Implementation
Expand Down Expand Up @@ -236,18 +237,11 @@ def agg_pandas( # noqa: PLR0915
new_names = [new_names[i] for i in index_map]
result_simple_aggs.columns = new_names

# Keep inplace=True to avoid making a redundant copy.
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
result_simple_aggs.reset_index(inplace=True) # noqa: PD002

if nunique_aggs:
result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique(
dropna=False
)
result_nunique_aggs.columns = list(nunique_aggs.keys())
# Keep inplace=True to avoid making a redundant copy.
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
result_nunique_aggs.reset_index(inplace=True) # noqa: PD002
if simple_aggs and nunique_aggs:
if (
set(result_simple_aggs.columns)
Expand All @@ -259,7 +253,11 @@ def agg_pandas( # noqa: PLR0915
"that aggregations have unique output names."
)
raise ValueError(msg)
result_aggs = result_simple_aggs.merge(result_nunique_aggs, on=keys)
result_aggs = horizontal_concat(
[result_simple_aggs, result_nunique_aggs],
implementation=implementation,
backend_version=backend_version,
)
elif nunique_aggs and not simple_aggs:
result_aggs = result_nunique_aggs
elif simple_aggs and not nunique_aggs:
Expand All @@ -269,6 +267,9 @@ def agg_pandas( # noqa: PLR0915
result_aggs = native_namespace.DataFrame(
list(grouped.groups.keys()), columns=keys
)
# Keep inplace=True to avoid making a redundant copy.
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
result_aggs.reset_index(inplace=True) # noqa: PD002
return from_dataframe(
select_columns_by_name(
result_aggs, output_names, backend_version, implementation
Expand Down
Loading