diff --git a/modin/backends/pandas/query_compiler.py b/modin/backends/pandas/query_compiler.py index 26d85231645..aa07a223120 100644 --- a/modin/backends/pandas/query_compiler.py +++ b/modin/backends/pandas/query_compiler.py @@ -2565,20 +2565,17 @@ def groupby_agg_builder(df, by=None, drop=False, partition_idx=None): missmatched_cols = pandas.Index([]) if by is not None: internal_by_df = by[internal_by] + if isinstance(internal_by_df, pandas.Series): internal_by_df = internal_by_df.to_frame() - if isinstance(internal_by_df, pandas.DataFrame): - missmatched_cols = internal_by_df.columns.difference(df.columns) - df = pandas.concat( - [df, internal_by_df[missmatched_cols]], - axis=1, - ) - internal_by_cols = internal_by_df.columns - elif is_multi_by: - internal_by_cols = pandas.Index([internal_by_df.name]) - else: - internal_by_cols = pandas.Index([internal_by_df]) + missmatched_cols = internal_by_df.columns.difference(df.columns) + df = pandas.concat( + [df, internal_by_df[missmatched_cols]], + axis=1, + copy=False, + ) + internal_by_cols = internal_by_df.columns external_by = by.columns.difference(internal_by) external_by_df = by[external_by].squeeze(axis=1)