Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REG: DataFrame/Series.transform with list and non-list dict values #40090

Merged
merged 3 commits into from
Feb 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.2.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Fixed regressions
Passing ``ascending=None`` is still considered invalid,
and the new error message suggests a proper usage
(``ascending`` must be a boolean or a list-like boolean).
- Fixed regression in :meth:`DataFrame.transform` and :meth:`Series.transform` giving incorrect column labels when passed a dictionary with a mix of list and non-list values (:issue:`40018`)
-

.. ---------------------------------------------------------------------------

Expand Down
49 changes: 26 additions & 23 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def transform_dict_like(self, func):
if len(func) == 0:
raise ValueError("No transform functions were provided")

self.validate_dictlike_arg("transform", obj, func)
func = self.normalize_dictlike_arg("transform", obj, func)

results: Dict[Hashable, FrameOrSeriesUnion] = {}
for name, how in func.items():
Expand Down Expand Up @@ -421,32 +421,17 @@ def agg_dict_like(self, _axis: int) -> FrameOrSeriesUnion:
-------
Result of aggregation.
"""
from pandas.core.reshape.concat import concat

obj = self.obj
arg = cast(AggFuncTypeDict, self.f)

is_aggregator = lambda x: isinstance(x, (list, tuple, dict))

if _axis != 0: # pragma: no cover
raise ValueError("Can only pass dict with axis=0")

selected_obj = obj._selected_obj

self.validate_dictlike_arg("agg", selected_obj, arg)

# if we have a dict of any non-scalars
# eg. {'A' : ['mean']}, normalize all to
# be list-likes
# Cannot use arg.values() because arg may be a Series
if any(is_aggregator(x) for _, x in arg.items()):
new_arg: AggFuncTypeDict = {}
for k, v in arg.items():
if not isinstance(v, (tuple, list, dict)):
new_arg[k] = [v]
else:
new_arg[k] = v
arg = new_arg

from pandas.core.reshape.concat import concat
arg = self.normalize_dictlike_arg("agg", selected_obj, arg)

if selected_obj.ndim == 1:
# key only used for output
Expand Down Expand Up @@ -540,14 +525,15 @@ def maybe_apply_multiple(self) -> Optional[FrameOrSeriesUnion]:
return None
return self.obj.aggregate(self.f, self.axis, *self.args, **self.kwargs)

def validate_dictlike_arg(
def normalize_dictlike_arg(
self, how: str, obj: FrameOrSeriesUnion, func: AggFuncTypeDict
) -> None:
) -> AggFuncTypeDict:
"""
Raise if dict-like argument is invalid.
Handler for dict-like argument.

Ensures that necessary columns exist if obj is a DataFrame, and
that a nested renamer is not passed.
that a nested renamer is not passed. Also normalizes to all lists
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe not here if too many changes for a backport, but maybe change function name to say normalize_dictlike_arg

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@simonjayhawkins Thanks for the suggestion. That's actually how I started this out, but then changed it to validate recalling util._validators (some of those do only validate, but a bunch also normalize). I'm seeing a mix of normalize or validate used. I'll change to normalize here, but maybe we should make this consistent.

Related: #19171

when values consist of a mix of lists and non-lists.
"""
assert how in ("apply", "agg", "transform")

Expand All @@ -567,6 +553,23 @@ def validate_dictlike_arg(
cols_sorted = list(safe_sort(list(cols)))
raise KeyError(f"Column(s) {cols_sorted} do not exist")

is_aggregator = lambda x: isinstance(x, (list, tuple, dict))

# if we have a dict of any non-scalars
# eg. {'A' : ['mean']}, normalize all to
# be list-likes
# Cannot use func.values() because func may be a Series
if any(is_aggregator(x) for _, x in func.items()):
new_func: AggFuncTypeDict = {}
for k, v in func.items():
if not is_aggregator(v):
# mypy can't realize v is not a list here
new_func[k] = [v] # type:ignore[list-item]
else:
new_func[k] = v
func = new_func
return func


class FrameApply(Apply):
obj: DataFrame
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/apply/test_frame_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,17 @@ def test_transform_dictlike(axis, float_frame, box):
tm.assert_frame_equal(result, expected)


def test_transform_dictlike_mixed():
    """Check DataFrame.transform with a dict mixing list and non-list values.

    The scalar entry ("c": "sqrt") must be treated like a one-element list so
    that the resulting columns form a (column, function) MultiIndex.
    """
    # GH 40018 - mix of lists and non-lists in values of a dictionary
    df = DataFrame({"a": [1, 2], "b": [1, 4], "c": [1, 4]})
    result = df.transform({"b": ["sqrt", "abs"], "c": "sqrt"})
    # Expected labels, in order: ("b", "sqrt"), ("b", "abs"), ("c", "sqrt")
    expected_columns = MultiIndex(
        levels=[("b", "c"), ("sqrt", "abs")],
        codes=[(0, 0, 1), (0, 1, 0)],
    )
    expected = DataFrame(
        [[1.0, 1, 1.0], [2.0, 4, 2.0]],
        columns=expected_columns,
    )
    tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"ops",
[
Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/apply/test_series_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pytest

from pandas import (
DataFrame,
MultiIndex,
Series,
concat,
)
Expand Down Expand Up @@ -55,6 +57,17 @@ def test_transform_dictlike(string_series, box):
tm.assert_frame_equal(result, expected)


def test_transform_dictlike_mixed():
    """Check Series.transform with a dict mixing list and non-list values.

    For a Series the dict keys only label the output; the scalar entry
    ("c": "sqrt") must be treated like a one-element list so that the
    resulting columns form a (key, function) MultiIndex.
    """
    # GH 40018 - mix of lists and non-lists in values of a dictionary
    ser = Series([1, 4])
    result = ser.transform({"b": ["sqrt", "abs"], "c": "sqrt"})
    # Expected labels, in order: ("b", "sqrt"), ("b", "abs"), ("c", "sqrt")
    expected_columns = MultiIndex(
        levels=[("b", "c"), ("sqrt", "abs")],
        codes=[(0, 0, 1), (0, 1, 0)],
    )
    expected = DataFrame(
        [[1.0, 1, 1.0], [2.0, 4, 2.0]],
        columns=expected_columns,
    )
    tm.assert_frame_equal(result, expected)


def test_transform_wont_agg(string_series):
# GH 35964
# we are trying to transform with an aggregator
Expand Down