diff --git a/modin/backends/pandas/query_compiler.py b/modin/backends/pandas/query_compiler.py index 521e013c558..065bf38ac2b 100644 --- a/modin/backends/pandas/query_compiler.py +++ b/modin/backends/pandas/query_compiler.py @@ -650,8 +650,6 @@ def is_monotonic_increasing(df): return self.default_to_pandas(is_monotonic_increasing) count = MapReduceFunction.register(pandas.DataFrame.count, pandas.DataFrame.sum) - max = MapReduceFunction.register(pandas.DataFrame.max) - min = MapReduceFunction.register(pandas.DataFrame.min) sum = MapReduceFunction.register(pandas.DataFrame.sum) prod = MapReduceFunction.register(pandas.DataFrame.prod) any = MapReduceFunction.register(pandas.DataFrame.any, pandas.DataFrame.any) @@ -662,6 +660,34 @@ def is_monotonic_increasing(df): axis=0, ) + def max(self, axis, **kwargs): + def map_func(df, **kwargs): + return pandas.DataFrame.max(df, **kwargs) + + def reduce_func(df, **kwargs): + if "numeric_only" in kwargs.keys() and kwargs["numeric_only"]: + kwargs = kwargs.copy() + kwargs["numeric_only"] = not kwargs["numeric_only"] + return pandas.DataFrame.max(df, **kwargs) + + return MapReduceFunction.register(map_func, reduce_func)( + self, axis=axis, **kwargs + ) + + def min(self, axis, **kwargs): + def map_func(df, **kwargs): + return pandas.DataFrame.min(df, **kwargs) + + def reduce_func(df, **kwargs): + if "numeric_only" in kwargs.keys() and kwargs["numeric_only"]: + kwargs = kwargs.copy() + kwargs["numeric_only"] = not kwargs["numeric_only"] + return pandas.DataFrame.min(df, **kwargs) + + return MapReduceFunction.register(map_func, reduce_func)( + self, axis=axis, **kwargs + ) + def mean(self, axis, **kwargs): if kwargs.get("level") is not None: return self.default_to_pandas(pandas.DataFrame.mean, axis=axis, **kwargs) diff --git a/modin/data_management/functions/mapreducefunction.py b/modin/data_management/functions/mapreducefunction.py index 35669e1fb20..e76426d4f7d 100644 --- a/modin/data_management/functions/mapreducefunction.py +++ b/modin/data_management/functions/mapreducefunction.py @@ -20,20 +20,11 @@ def call(cls, map_function, reduce_function, **call_kwds): def caller(query_compiler, *args, **kwargs): preserve_index = call_kwds.pop("preserve_index", True) axis = call_kwds.get("axis", kwargs.get("axis")) - if kwargs.get("numeric_only", "is_not_exist") == "is_not_exist": - kwargs_for_reduce = kwargs - else: - kwargs_for_reduce = kwargs.copy() - if kwargs_for_reduce["numeric_only"]: - kwargs_for_reduce["numeric_only"] = not kwargs_for_reduce[ - "numeric_only" - ] - return query_compiler.__constructor__( query_compiler._modin_frame._map_reduce( cls.validate_axis(axis), lambda x: map_function(x, *args, **kwargs), - lambda y: reduce_function(y, *args, **kwargs_for_reduce), + lambda y: reduce_function(y, *args, **kwargs), preserve_index=preserve_index, ) )