Skip to content

Commit

Permalink
REF: share GroupBy.transform (#41308)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored May 4, 2021
1 parent 88ce933 commit a997bab
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 68 deletions.
78 changes: 13 additions & 65 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,35 +526,9 @@ def _aggregate_named(self, func, *args, **kwargs):
# Series flavour of ``transform`` (decorated with klass="Series").
# NOTE(review): this span looks like a flattened diff -- indentation is gone and
# the inline implementation below appears to be the pre-refactor body, with the
# trailing ``return self._transform(...)`` as its replacement. Confirm against
# the repository before treating the statement order here as executable code.
@Substitution(klass="Series")
@Appender(_transform_template)
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):

# Numba branch: taken when maybe_use_numba(engine) is truthy. The selected
# object is converted with to_frame() before being handed to
# _transform_with_numba; the result is flattened with ravel() and rebuilt via
# self.obj._constructor using the selection's index and name.
if maybe_use_numba(engine):
with group_selection_context(self):
data = self._selected_obj
result = self._transform_with_numba(
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
)
return self.obj._constructor(
result.ravel(), index=data.index, name=data.name
)

# Use whatever com.get_cython_func returns for func when that is truthy,
# otherwise keep func unchanged.
func = com.get_cython_func(func) or func

# Anything that is not a string goes through the general transform path.
if not isinstance(func, str):
return self._transform_general(func, *args, **kwargs)

# String names must appear in base.transform_kernel_allowlist.
elif func not in base.transform_kernel_allowlist:
msg = f"'{func}' is not a valid function name for transform(name)"
raise ValueError(msg)
elif func in base.cythonized_kernels or func in base.transformation_kernels:
# cythonized transform or canned "agg+broadcast"
return getattr(self, func)(*args, **kwargs)
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
# Temporarily set observed for dealing with categoricals.
with com.temp_setattr(self, "observed", True):
result = getattr(self, func)(*args, **kwargs)
return self._wrap_transform_fast_result(result)
# Delegation to the shared helper, forwarding every argument unchanged.
return self._transform(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
"""
Expand Down Expand Up @@ -586,6 +560,9 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
result.name = self._selected_obj.name
return result

def _can_use_transform_fast(self, result) -> bool:
return True

def _wrap_transform_fast_result(self, result: Series) -> Series:
"""
fast version of transform, only applicable to
Expand Down Expand Up @@ -1334,43 +1311,14 @@ def _transform_general(self, func, *args, **kwargs):
# DataFrame flavour of ``transform`` (decorated with klass="DataFrame").
# NOTE(review): this span looks like a flattened diff -- indentation is gone,
# the leading ``return self._transform(...)`` appears to be the post-refactor
# body, and everything after it appears to be the removed inline implementation.
# Confirm against the repository before treating the statement order here as
# executable code.
@Substitution(klass="DataFrame")
@Appender(_transform_template)
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
# Delegation to the shared helper, forwarding every argument unchanged.
return self._transform(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

# Numba branch: taken when maybe_use_numba(engine) is truthy. The selected
# object is passed to _transform_with_numba as-is and the result is rebuilt via
# self.obj._constructor using the selection's index and columns.
if maybe_use_numba(engine):
with group_selection_context(self):
data = self._selected_obj
result = self._transform_with_numba(
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
)
return self.obj._constructor(result, index=data.index, columns=data.columns)

# optimized transforms
# Use whatever com.get_cython_func returns for func when that is truthy,
# otherwise keep func unchanged.
func = com.get_cython_func(func) or func

# Anything that is not a string goes through the general transform path.
if not isinstance(func, str):
return self._transform_general(func, *args, **kwargs)

# String names must appear in base.transform_kernel_allowlist.
elif func not in base.transform_kernel_allowlist:
msg = f"'{func}' is not a valid function name for transform(name)"
raise ValueError(msg)
elif func in base.cythonized_kernels or func in base.transformation_kernels:
# cythonized transformation or canned "reduction+broadcast"
return getattr(self, func)(*args, **kwargs)
# GH 30918
# Use _transform_fast only when we know func is an aggregation
if func in base.reduction_kernels:
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
# Temporarily set observed for dealing with categoricals.
with com.temp_setattr(self, "observed", True):
result = getattr(self, func)(*args, **kwargs)

# The fast wrapper is used only when the result is a DataFrame whose columns
# equal those of self._obj_with_exclusions.
if isinstance(result, DataFrame) and result.columns.equals(
self._obj_with_exclusions.columns
):
return self._wrap_transform_fast_result(result)

# Fallback: the general transform path.
return self._transform_general(func, *args, **kwargs)
def _can_use_transform_fast(self, result) -> bool:
return isinstance(result, DataFrame) and result.columns.equals(
self._obj_with_exclusions.columns
)

def _wrap_transform_fast_result(self, result: DataFrame) -> DataFrame:
"""
Expand Down
57 changes: 54 additions & 3 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class providing the base-class of operations.
Sequence,
TypeVar,
Union,
cast,
)

import numpy as np
Expand Down Expand Up @@ -104,7 +105,10 @@ class providing the base-class of operations.
from pandas.core.internals.blocks import ensure_block_shape
from pandas.core.series import Series
from pandas.core.sorting import get_group_index_sorter
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
maybe_use_numba,
)

if TYPE_CHECKING:
from typing import Literal
Expand Down Expand Up @@ -1398,8 +1402,55 @@ def _cython_transform(

return self._wrap_transformed_output(output)

def transform(self, func, *args, **kwargs):
    """
    Abstract placeholder for ``transform``.

    Parameters
    ----------
    func, *args, **kwargs
        Accepted for signature compatibility; none of them is used here.

    Raises
    ------
    AbstractMethodError
        Always, constructed with ``self``.
    """
    raise AbstractMethodError(self)
@final
def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
    """
    Shared implementation that the public ``transform`` methods delegate to.

    Parameters
    ----------
    func : callable or str
        The transformation. A string must be a name listed in
        ``base.transform_kernel_allowlist``.
    *args
        Positional arguments forwarded to ``func``.
    engine : optional
        Passed to ``maybe_use_numba`` to decide whether the numba path runs.
    engine_kwargs : optional
        Forwarded to ``_transform_with_numba`` on the numba path.
    **kwargs
        Keyword arguments forwarded to ``func``.

    Returns
    -------
    The transformed object. On the numba path it is built with
    ``self.obj._constructor``; otherwise it is whatever the selected
    kernel, ``_wrap_transform_fast_result`` or ``_transform_general`` returns.

    Raises
    ------
    ValueError
        If ``func`` is a string that is not in
        ``base.transform_kernel_allowlist``.
    """

    if maybe_use_numba(engine):
        # TODO: tests with self._selected_obj.ndim == 1 on DataFrameGroupBy
        with group_selection_context(self):
            data = self._selected_obj
        # _transform_with_numba is always given a 2D object.
        df = data if data.ndim == 2 else data.to_frame()
        result = self._transform_with_numba(
            df, func, *args, engine_kwargs=engine_kwargs, **kwargs
        )
        if self.obj.ndim == 2:
            return cast(DataFrame, self.obj)._constructor(
                result, index=data.index, columns=data.columns
            )
        else:
            return cast(Series, self.obj)._constructor(
                result.ravel(), index=data.index, name=data.name
            )

    # optimized transforms
    func = com.get_cython_func(func) or func

    if not isinstance(func, str):
        return self._transform_general(func, *args, **kwargs)

    elif func not in base.transform_kernel_allowlist:
        msg = f"'{func}' is not a valid function name for transform(name)"
        raise ValueError(msg)
    elif func in base.cythonized_kernels or func in base.transformation_kernels:
        # cythonized transform or canned "agg+broadcast"
        return getattr(self, func)(*args, **kwargs)

    else:
        # i.e. func in base.reduction_kernels

        # GH#30918 Use _transform_fast only when we know func is an aggregation
        # If func is a reduction, we need to broadcast the
        # result to the whole group. Compute func result
        # and deal with possible broadcasting below.
        # Temporarily set observed for dealing with categoricals.
        with com.temp_setattr(self, "observed", True):
            result = getattr(self, func)(*args, **kwargs)

        if self._can_use_transform_fast(result):
            return self._wrap_transform_fast_result(result)

        # only reached for DataFrameGroupBy
        return self._transform_general(func, *args, **kwargs)

# -----------------------------------------------------------------
# Utilities
Expand Down

0 comments on commit a997bab

Please sign in to comment.