Skip to content

Commit

Permalink
REG: DataFrame/Series.transform with list and non-list dict values (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rhshadrach authored Feb 27, 2021
1 parent 11afc76 commit 47c6d16
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 23 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.2.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Fixed regressions
Passing ``ascending=None`` is still considered invalid,
and the new error message suggests a proper usage
(``ascending`` must be a boolean or a list-like boolean).
- Fixed regression in :meth:`DataFrame.transform` and :meth:`Series.transform` giving incorrect column labels when passed a dictionary with a mix of list and non-list values (:issue:`40018`)
-

.. ---------------------------------------------------------------------------
Expand Down
49 changes: 26 additions & 23 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def transform_dict_like(self, func):
if len(func) == 0:
raise ValueError("No transform functions were provided")

self.validate_dictlike_arg("transform", obj, func)
func = self.normalize_dictlike_arg("transform", obj, func)

results: Dict[Hashable, FrameOrSeriesUnion] = {}
for name, how in func.items():
Expand Down Expand Up @@ -405,32 +405,17 @@ def agg_dict_like(self, _axis: int) -> FrameOrSeriesUnion:
-------
Result of aggregation.
"""
from pandas.core.reshape.concat import concat

obj = self.obj
arg = cast(AggFuncTypeDict, self.f)

is_aggregator = lambda x: isinstance(x, (list, tuple, dict))

if _axis != 0: # pragma: no cover
raise ValueError("Can only pass dict with axis=0")

selected_obj = obj._selected_obj

self.validate_dictlike_arg("agg", selected_obj, arg)

# if we have a dict of any non-scalars
# eg. {'A' : ['mean']}, normalize all to
# be list-likes
# Cannot use arg.values() because arg may be a Series
if any(is_aggregator(x) for _, x in arg.items()):
new_arg: AggFuncTypeDict = {}
for k, v in arg.items():
if not isinstance(v, (tuple, list, dict)):
new_arg[k] = [v]
else:
new_arg[k] = v
arg = new_arg

from pandas.core.reshape.concat import concat
arg = self.normalize_dictlike_arg("agg", selected_obj, arg)

if selected_obj.ndim == 1:
# key only used for output
Expand Down Expand Up @@ -524,14 +509,15 @@ def maybe_apply_multiple(self) -> Optional[FrameOrSeriesUnion]:
return None
return self.obj.aggregate(self.f, self.axis, *self.args, **self.kwargs)

def validate_dictlike_arg(
def normalize_dictlike_arg(
self, how: str, obj: FrameOrSeriesUnion, func: AggFuncTypeDict
) -> None:
) -> AggFuncTypeDict:
"""
Raise if dict-like argument is invalid.
Handler for dict-like argument.
Ensures that necessary columns exist if obj is a DataFrame, and
that a nested renamer is not passed.
that a nested renamer is not passed. Also normalizes to all lists
when values consist of a mix of lists and non-lists.
"""
assert how in ("apply", "agg", "transform")

Expand All @@ -551,6 +537,23 @@ def validate_dictlike_arg(
cols_sorted = list(safe_sort(list(cols)))
raise KeyError(f"Column(s) {cols_sorted} do not exist")

is_aggregator = lambda x: isinstance(x, (list, tuple, dict))

# if we have a dict of any non-scalars
# eg. {'A' : ['mean']}, normalize all to
# be list-likes
# Cannot use func.values() because func may be a Series
if any(is_aggregator(x) for _, x in func.items()):
new_func: AggFuncTypeDict = {}
for k, v in func.items():
if not is_aggregator(v):
# mypy can't realize v is not a list here
new_func[k] = [v] # type:ignore[list-item]
else:
new_func[k] = v
func = new_func
return func


class FrameApply(Apply):
obj: DataFrame
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/apply/test_frame_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,17 @@ def test_transform_dictlike(axis, float_frame, box):
tm.assert_frame_equal(result, expected)


def test_transform_dictlike_mixed():
    """DataFrame.transform labels columns correctly for mixed dict values."""
    # GH 40018 - mix of lists and non-lists in values of a dictionary
    frame = DataFrame({"a": [1, 2], "b": [1, 4], "c": [1, 4]})
    funcs = {"b": ["sqrt", "abs"], "c": "sqrt"}

    # Expected labels, in order: ("b", "sqrt"), ("b", "abs"), ("c", "sqrt")
    expected_columns = MultiIndex(
        levels=[("b", "c"), ("sqrt", "abs")],
        codes=[(0, 0, 1), (0, 1, 0)],
    )
    expected = DataFrame(
        [[1.0, 1, 1.0], [2.0, 4, 2.0]],
        columns=expected_columns,
    )

    result = frame.transform(funcs)
    tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"ops",
[
Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/apply/test_series_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pytest

from pandas import (
DataFrame,
MultiIndex,
Series,
concat,
)
Expand Down Expand Up @@ -55,6 +57,17 @@ def test_transform_dictlike(string_series, box):
tm.assert_frame_equal(result, expected)


def test_transform_dictlike_mixed():
    """Series.transform labels columns correctly for mixed dict values."""
    # GH 40018 - mix of lists and non-lists in values of a dictionary
    ser = Series([1, 4])
    funcs = {"b": ["sqrt", "abs"], "c": "sqrt"}

    # Expected labels, in order: ("b", "sqrt"), ("b", "abs"), ("c", "sqrt")
    expected_columns = MultiIndex(
        levels=[("b", "c"), ("sqrt", "abs")],
        codes=[(0, 0, 1), (0, 1, 0)],
    )
    expected = DataFrame(
        [[1.0, 1, 1.0], [2.0, 4, 2.0]],
        columns=expected_columns,
    )

    result = ser.transform(funcs)
    tm.assert_frame_equal(result, expected)


def test_transform_wont_agg(string_series):
# GH 35964
# we are trying to transform with an aggregator
Expand Down

0 comments on commit 47c6d16

Please sign in to comment.