From a997bab40ca9b4f51463a29ba05aa8c5ff2e03f4 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 4 May 2021 14:36:53 -0700 Subject: [PATCH] REF: share GroupBy.transform (#41308) --- pandas/core/groupby/generic.py | 78 ++++++---------------------------- pandas/core/groupby/groupby.py | 57 +++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 68 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 8e5f773c1a055..18506b871bda6 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -526,35 +526,9 @@ def _aggregate_named(self, func, *args, **kwargs): @Substitution(klass="Series") @Appender(_transform_template) def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): - - if maybe_use_numba(engine): - with group_selection_context(self): - data = self._selected_obj - result = self._transform_with_numba( - data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs - ) - return self.obj._constructor( - result.ravel(), index=data.index, name=data.name - ) - - func = com.get_cython_func(func) or func - - if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) - - elif func not in base.transform_kernel_allowlist: - msg = f"'{func}' is not a valid function name for transform(name)" - raise ValueError(msg) - elif func in base.cythonized_kernels or func in base.transformation_kernels: - # cythonized transform or canned "agg+broadcast" - return getattr(self, func)(*args, **kwargs) - # If func is a reduction, we need to broadcast the - # result to the whole group. Compute func result - # and deal with possible broadcasting below. - # Temporarily set observed for dealing with categoricals. - with com.temp_setattr(self, "observed", True): - result = getattr(self, func)(*args, **kwargs) - return self._wrap_transform_fast_result(result) + return self._transform( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) def _transform_general(self, func: Callable, *args, **kwargs) -> Series: """ @@ -586,6 +560,9 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series: result.name = self._selected_obj.name return result + def _can_use_transform_fast(self, result) -> bool: + return True + def _wrap_transform_fast_result(self, result: Series) -> Series: """ fast version of transform, only applicable to @@ -1334,43 +1311,14 @@ def _transform_general(self, func, *args, **kwargs): @Substitution(klass="DataFrame") @Appender(_transform_template) def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): + return self._transform( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) - if maybe_use_numba(engine): - with group_selection_context(self): - data = self._selected_obj - result = self._transform_with_numba( - data, func, *args, engine_kwargs=engine_kwargs, **kwargs - ) - return self.obj._constructor(result, index=data.index, columns=data.columns) - - # optimized transforms - func = com.get_cython_func(func) or func - - if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) - - elif func not in base.transform_kernel_allowlist: - msg = f"'{func}' is not a valid function name for transform(name)" - raise ValueError(msg) - elif func in base.cythonized_kernels or func in base.transformation_kernels: - # cythonized transformation or canned "reduction+broadcast" - return getattr(self, func)(*args, **kwargs) - # GH 30918 - # Use _transform_fast only when we know func is an aggregation - if func in base.reduction_kernels: - # If func is a reduction, we need to broadcast the - # result to the whole group. Compute func result - # and deal with possible broadcasting below. - # Temporarily set observed for dealing with categoricals. - with com.temp_setattr(self, "observed", True): - result = getattr(self, func)(*args, **kwargs) - - if isinstance(result, DataFrame) and result.columns.equals( - self._obj_with_exclusions.columns - ): - return self._wrap_transform_fast_result(result) - - return self._transform_general(func, *args, **kwargs) + def _can_use_transform_fast(self, result) -> bool: + return isinstance(result, DataFrame) and result.columns.equals( + self._obj_with_exclusions.columns + ) def _wrap_transform_fast_result(self, result: DataFrame) -> DataFrame: """ diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f2041951b9e49..7a8b41fbdf141 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -29,6 +29,7 @@ class providing the base-class of operations. Sequence, TypeVar, Union, + cast, ) import numpy as np @@ -104,7 +105,10 @@ class providing the base-class of operations. from pandas.core.internals.blocks import ensure_block_shape from pandas.core.series import Series from pandas.core.sorting import get_group_index_sorter -from pandas.core.util.numba_ import NUMBA_FUNC_CACHE +from pandas.core.util.numba_ import ( + NUMBA_FUNC_CACHE, + maybe_use_numba, +) if TYPE_CHECKING: from typing import Literal @@ -1398,8 +1402,55 @@ def _cython_transform( return self._wrap_transformed_output(output) - def transform(self, func, *args, **kwargs): - raise AbstractMethodError(self) + @final + def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): + + if maybe_use_numba(engine): + # TODO: tests with self._selected_obj.ndim == 1 on DataFrameGroupBy + with group_selection_context(self): + data = self._selected_obj + df = data if data.ndim == 2 else data.to_frame() + result = self._transform_with_numba( + df, func, *args, engine_kwargs=engine_kwargs, **kwargs + ) + if self.obj.ndim == 2: + return cast(DataFrame, self.obj)._constructor( + result, index=data.index, columns=data.columns + ) + else: + return cast(Series, self.obj)._constructor( + result.ravel(), index=data.index, name=data.name + ) + + # optimized transforms + func = com.get_cython_func(func) or func + + if not isinstance(func, str): + return self._transform_general(func, *args, **kwargs) + + elif func not in base.transform_kernel_allowlist: + msg = f"'{func}' is not a valid function name for transform(name)" + raise ValueError(msg) + elif func in base.cythonized_kernels or func in base.transformation_kernels: + # cythonized transform or canned "agg+broadcast" + return getattr(self, func)(*args, **kwargs) + + else: + # i.e. func in base.reduction_kernels + + # GH#30918 Use _transform_fast only when we know func is an aggregation + # If func is a reduction, we need to broadcast the + # result to the whole group. Compute func result + # and deal with possible broadcasting below. + # Temporarily set observed for dealing with categoricals. + with com.temp_setattr(self, "observed", True): + result = getattr(self, func)(*args, **kwargs) + + if self._can_use_transform_fast(result): + return self._wrap_transform_fast_result(result) + + # only reached for DataFrameGroupBy + return self._transform_general(func, *args, **kwargs) # ----------------------------------------------------------------- # Utilities