diff --git a/modin/pandas/base.py b/modin/pandas/base.py index 5eac99d2994..740da1b6801 100644 --- a/modin/pandas/base.py +++ b/modin/pandas/base.py @@ -517,22 +517,22 @@ def aggregate(self, func=None, axis=0, *args, **kwargs): agg = aggregate - def _aggregate(self, arg, *args, **kwargs): + def _aggregate(self, func, *args, **kwargs): _axis = kwargs.pop("_axis", 0) kwargs.pop("_level", None) - if isinstance(arg, str): + if isinstance(func, str): kwargs.pop("is_transform", None) - return self._string_function(arg, *args, **kwargs) + return self._string_function(func, *args, **kwargs) # Dictionaries have complex behavior because they can be renamed here. - elif isinstance(arg, dict): - return self._default_to_pandas("agg", arg, *args, **kwargs) - elif is_list_like(arg) or callable(arg): + elif func is None or isinstance(func, dict): + return self._default_to_pandas("agg", func, *args, **kwargs) + elif is_list_like(func) or callable(func): kwargs.pop("is_transform", None) - return self.apply(arg, axis=_axis, args=args, **kwargs) + return self.apply(func, axis=_axis, args=args, **kwargs) else: - raise TypeError("type {} is not callable".format(type(arg))) + raise TypeError("type {} is not callable".format(type(func))) def _string_function(self, func, *args, **kwargs): assert isinstance(func, str) diff --git a/modin/pandas/test/dataframe/test_udf.py b/modin/pandas/test/dataframe/test_udf.py index 23567dc1345..1278c4f036c 100644 --- a/modin/pandas/test/dataframe/test_udf.py +++ b/modin/pandas/test/dataframe/test_udf.py @@ -42,6 +42,18 @@ matplotlib.use("Agg") +def test_agg_dict(): + md_df, pd_df = create_test_dfs(test_data_values[0]) + agg_dict = {pd_df.columns[0]: "sum", pd_df.columns[-1]: ("sum", "count")} + eval_general(md_df, pd_df, lambda df: df.agg(agg_dict), raising_exceptions=True) + + agg_dict = { + "new_col1": (pd_df.columns[0], "sum"), + "new_col2": (pd_df.columns[-1], "count"), + } + eval_general(md_df, pd_df, lambda df: df.agg(**agg_dict), raising_exceptions=True) + + @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize( "func",