Skip to content

Commit

Permalink
FIX-modin-project#2269: Added required arguments for groupby_agg
Browse files Browse the repository at this point in the history
Moved wrap_udf_function into backend because omnisci doesn't support
executing lambdas.

Signed-off-by: Gregory Shimansky <[email protected]>
  • Loading branch information
gshimansky authored and YarShev committed Oct 30, 2020
1 parent 657916f commit a0e1b74
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 5 deletions.
20 changes: 18 additions & 2 deletions modin/backends/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,14 +1396,30 @@ def groupby_size(
drop=drop,
)

def groupby_agg(self, by, axis, agg_func, groupby_args, agg_args, drop=False):
def groupby_agg(
self,
by,
is_multi_by,
idx_name,
axis,
agg_func,
agg_args,
agg_kwargs,
groupby_kwargs,
drop_,
drop=False,
):
return GroupByDefault.register(pandas.core.groupby.DataFrameGroupBy.aggregate)(
self,
by=by,
is_multi_by=is_multi_by,
idx_name=idx_name,
axis=axis,
agg_func=agg_func,
groupby_args=groupby_args,
agg_args=agg_args,
agg_kwargs=agg_kwargs,
groupby_kwargs=groupby_kwargs,
drop_=drop_,
drop=drop,
)

Expand Down
2 changes: 2 additions & 0 deletions modin/backends/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,6 +2596,8 @@ def groupby_agg(
drop_,
drop=False,
):
agg_func = wrap_udf_function(agg_func)

if is_multi_by:
return self.default_to_pandas(agg_func, *agg_args, **agg_kwargs)

Expand Down
18 changes: 18 additions & 0 deletions modin/experimental/backends/omnisci/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,24 @@ def groupby_count(self, by, axis, groupby_args, map_args, **kwargs):
)
return self.__constructor__(new_frame)

def groupby_agg(self,
by,
is_multi_by,
idx_name,
axis,
agg_func,
agg_args,
agg_kwargs,
groupby_kwargs,
drop_,
drop=False,
):
# TODO: handle is_multi_by, idx_name and drop args
new_frame = self._modin_frame.groupby_agg(
by, axis, agg_func, groupby_kwargs, *agg_args, **agg_kwargs
)
return self.__constructor__(new_frame)

def groupby_dict_agg(self, by, func_dict, groupby_args, agg_args, drop=False):
"""Apply aggregation functions to a grouped dataframe per-column.
Expand Down
10 changes: 10 additions & 0 deletions modin/experimental/engines/omnisci_on_ray/test/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def groupby_count(df, cols, as_index, **kwargs):

run_and_compare(groupby_count, data=self.data, cols=cols, as_index=as_index)

@pytest.mask.xfail(reason="Currently mean() passes a lambda into backend which cannot be executed on omnisci backend")
@pytest.mark.parametrize("cols", cols_value)
@pytest.mark.parametrize("as_index", bool_arg_values)
def test_groupby_mean(self, cols, as_index):
Expand Down Expand Up @@ -577,6 +578,15 @@ def groupby(df, **kwargs):

run_and_compare(groupby, data=self.data)

@pytest.mask.xfail(reason="Function specified as a string should be passed into backend API, but currently it is transformed into a lambda")
@pytest.mark.parametrize("cols", cols_value)
@pytest.mark.parametrize("as_index", bool_arg_values)
def test_groupby_agg_mean(self, cols, as_index):
def groupby_mean(df, cols, as_index, **kwargs):
return df.groupby(cols, as_index=as_index).agg("mean")

run_and_compare(groupby_mean, data=self.data, cols=cols, as_index=as_index)

taxi_data = {
"a": [1, 1, 2, 2],
"b": [11, 21, 12, 11],
Expand Down
4 changes: 1 addition & 3 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import pandas.core.common as com

from modin.error_message import ErrorMessage
from modin.utils import _inherit_docstrings, wrap_udf_function, try_cast_to_pandas
from modin.utils import _inherit_docstrings, try_cast_to_pandas
from modin.config import IsExperimental
from .series import Series

Expand Down Expand Up @@ -834,8 +834,6 @@ def _apply_agg_function(self, f, drop=True, *args, **kwargs):
"""
assert callable(f), "'{0}' object is not callable".format(type(f))

f = wrap_udf_function(f)

new_manager = self._query_compiler.groupby_agg(
by=self._by,
is_multi_by=self._is_multi_by,
Expand Down

0 comments on commit a0e1b74

Please sign in to comment.