Skip to content

Commit

Permalink
Fix numeric_only logic in frames_test for Pandas 2 (#28422)
Browse files Browse the repository at this point in the history
  • Loading branch information
caneff authored Sep 18, 2023
1 parent 8871a4e commit 71c68ca
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 30 deletions.
22 changes: 16 additions & 6 deletions sdks/python/apache_beam/dataframe/frame_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,8 @@ def wrap(func):

removed_arg_names = removed_args if removed_args is not None else []

# We would need to add position only arguments if they ever become a thing
# in Pandas (as of 2.1 currently they aren't).
base_arg_spec = getfullargspec(unwrap(getattr(base_type, func.__name__)))
base_arg_names = base_arg_spec.args
# Some arguments are keyword only and we still want to check against those.
Expand All @@ -514,6 +516,9 @@ def wrap(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if len(args) > len(base_arg_names):
raise TypeError(f"{func.__name__} got too many positioned arguments.")

for name, value in zip(base_arg_names, args):
if name in kwargs:
raise TypeError(
Expand All @@ -523,7 +528,7 @@ def wrapper(*args, **kwargs):
# Still have to populate these for the Beam function signature.
if removed_args:
for name in removed_args:
if not name in kwargs:
if name not in kwargs:
kwargs[name] = None
return func(**kwargs)

Expand Down Expand Up @@ -646,13 +651,18 @@ def wrap(func):
return func

base_argspec = getfullargspec(unwrap(getattr(base_type, func.__name__)))
if not base_argspec.defaults:
if not base_argspec.defaults and not base_argspec.kwonlydefaults:
return func

arg_to_default = dict(
zip(
base_argspec.args[-len(base_argspec.defaults):],
base_argspec.defaults))
arg_to_default = {}
if base_argspec.defaults:
arg_to_default.update(
zip(
base_argspec.args[-len(base_argspec.defaults):],
base_argspec.defaults))

if base_argspec.kwonlydefaults:
arg_to_default.update(base_argspec.kwonlydefaults)

unwrapped_func = unwrap(func)
# args that do not have defaults in func, but do have defaults in base
Expand Down
47 changes: 46 additions & 1 deletion sdks/python/apache_beam/dataframe/frame_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def add_one(frame):

def test_args_to_kwargs(self):
class Base(object):
def func(self, a=1, b=2, c=3):
def func(self, a=1, b=2, c=3, *, kw_only=4):
pass

class Proxy(object):
Expand All @@ -87,6 +87,9 @@ def func(self, **kwargs):
self.assertEqual(proxy.func(2, 4, 6), {'a': 2, 'b': 4, 'c': 6})
self.assertEqual(proxy.func(2, c=6), {'a': 2, 'c': 6})
self.assertEqual(proxy.func(c=6, a=2), {'a': 2, 'c': 6})
self.assertEqual(proxy.func(2, kw_only=20), {'a': 2, 'kw_only': 20})
with self.assertRaises(TypeError): # got too many positioned arguments
proxy.func(2, 4, 6, 8)

def test_args_to_kwargs_populates_defaults(self):
class Base(object):
Expand Down Expand Up @@ -129,6 +132,48 @@ def func_removed_args(self, a, c, **kwargs):
proxy.func_removed_args()
self.assertEqual(proxy.func_removed_args(12, d=100), {'a': 12, 'd': 100})

def test_args_to_kwargs_populates_default_handles_kw_only(self):
class Base(object):
def func(self, a, b=2, c=3, *, kw_only=4):
pass

class ProxyUsesKwOnly(object):
@frame_base.args_to_kwargs(Base)
@frame_base.populate_defaults(Base)
def func(self, a, kw_only, **kwargs):
return dict(kwargs, a=a, kw_only=kw_only)

proxy = ProxyUsesKwOnly()

# pylint: disable=too-many-function-args,no-value-for-parameter
with self.assertRaises(TypeError): # missing 1 required positional argument
proxy.func()

self.assertEqual(proxy.func(100), {'a': 100, 'kw_only': 4})
self.assertEqual(
proxy.func(2, 4, 6, kw_only=8), {
'a': 2, 'b': 4, 'c': 6, 'kw_only': 8
})
with self.assertRaises(TypeError):
proxy.func(2, 4, 6, 8) # got too many positioned arguments

class ProxyDoesntUseKwOnly(object):
@frame_base.args_to_kwargs(Base)
@frame_base.populate_defaults(Base)
def func(self, a, **kwargs):
return dict(kwargs, a=a)

proxy = ProxyDoesntUseKwOnly()

# pylint: disable=too-many-function-args,no-value-for-parameter
with self.assertRaises(TypeError): # missing 1 required positional argument
proxy.func()
self.assertEqual(proxy.func(100), {'a': 100})
self.assertEqual(
proxy.func(2, 4, 6, kw_only=8), {
'a': 2, 'b': 4, 'c': 6, 'kw_only': 8
})


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/dataframe/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def sort_index(self, axis, **kwargs):
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'sort_index',
lambda df: df.sort_index(axis, **kwargs),
lambda df: df.sort_index(axis=axis, **kwargs),
[self._expr],
requires_partition_by=partitionings.Arbitrary(),
preserves_partition_by=partitionings.Arbitrary(),
Expand Down Expand Up @@ -2689,7 +2689,7 @@ def set_axis(self, labels, axis, **kwargs):
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'set_axis',
lambda df: df.set_axis(labels, axis, **kwargs),
lambda df: df.set_axis(labels, axis=axis, **kwargs),
[self._expr],
requires_partition_by=partitionings.Arbitrary(),
preserves_partition_by=partitionings.Arbitrary()))
Expand Down
80 changes: 59 additions & 21 deletions sdks/python/apache_beam/dataframe/frames_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import re
import unittest
import warnings
from typing import Dict

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1634,6 +1635,30 @@ def test_pivot_no_index_provided_on_multiindex(self):
# https://github.com/pandas-dev/pandas/issues/40139
ALL_GROUPING_AGGREGATIONS = sorted(
set(frames.ALL_AGGREGATIONS) - set(('kurt', 'kurtosis')))
AGGREGATIONS_WHERE_NUMERIC_ONLY_DEFAULTS_TO_TRUE_IN_PANDAS_1 = set(
frames.ALL_AGGREGATIONS) - set((
'nunique',
'size',
'count',
'idxmin',
'idxmax',
'mode',
'rank',
'all',
'any',
'describe'))


def numeric_only_kwargs_for_pandas_2(agg_type: str) -> Dict[str, bool]:
"""Get proper arguments for numeric_only.
Behavior for numeric_only in these methods changed in Pandas 2 to default
to False instead of True, so explicitly make it True in Pandas 2."""
if PD_VERSION >= (2, 0) and (
agg_type in AGGREGATIONS_WHERE_NUMERIC_ONLY_DEFAULTS_TO_TRUE_IN_PANDAS_1):
return {'numeric_only': True}
else:
return {}


class GroupByTest(_AbstractFrameTest):
Expand All @@ -1650,8 +1675,9 @@ def test_groupby_agg(self, agg_type):
self.skipTest(
"https://github.com/apache/beam/issues/20967: proxy generation of "
"DataFrameGroupBy.describe fails in pandas < 1.2")
kwargs = numeric_only_kwargs_for_pandas_2(agg_type)
self._run_test(
lambda df: df.groupby('group').agg(agg_type),
lambda df: df.groupby('group').agg(agg_type, **kwargs),
GROUPBY_DF,
check_proxy=False)

Expand All @@ -1661,8 +1687,10 @@ def test_groupby_with_filter(self, agg_type):
self.skipTest(
"https://github.com/apache/beam/issues/20967: proxy generation of "
"DataFrameGroupBy.describe fails in pandas < 1.2")
kwargs = numeric_only_kwargs_for_pandas_2(agg_type)
self._run_test(
lambda df: getattr(df[df.foo > 30].groupby('group'), agg_type)(),
lambda df: getattr(df[df.foo > 30].groupby('group'), agg_type)
(**kwargs),
GROUPBY_DF,
check_proxy=False)

Expand All @@ -1673,8 +1701,9 @@ def test_groupby(self, agg_type):
"https://github.com/apache/beam/issues/20967: proxy generation of "
"DataFrameGroupBy.describe fails in pandas < 1.2")

kwargs = numeric_only_kwargs_for_pandas_2(agg_type)
self._run_test(
lambda df: getattr(df.groupby('group'), agg_type)(),
lambda df: getattr(df.groupby('group'), agg_type)(**kwargs),
GROUPBY_DF,
check_proxy=False)

Expand All @@ -1685,8 +1714,10 @@ def test_groupby_series(self, agg_type):
"https://github.com/apache/beam/issues/20967: proxy generation of "
"DataFrameGroupBy.describe fails in pandas < 1.2")

kwargs = numeric_only_kwargs_for_pandas_2(agg_type)
self._run_test(
lambda df: getattr(df[df.foo > 40].groupby(df.group), agg_type)(),
lambda df: getattr(df[df.foo > 40].groupby(df.group), agg_type)
(**kwargs),
GROUPBY_DF,
check_proxy=False)

Expand Down Expand Up @@ -1717,21 +1748,26 @@ def test_groupby_project_series(self, agg_type):
"https://github.com/apache/beam/issues/20895: "
"SeriesGroupBy.{corr, cov} do not raise the expected error.")

self._run_test(lambda df: getattr(df.groupby('group').foo, agg_type)(), df)
self._run_test(lambda df: getattr(df.groupby('group').bar, agg_type)(), df)
kwargs = numeric_only_kwargs_for_pandas_2(agg_type)
self._run_test(
lambda df: getattr(df.groupby('group')['foo'], agg_type)(), df)
lambda df: getattr(df.groupby('group').foo, agg_type)(**kwargs), df)
self._run_test(
lambda df: getattr(df.groupby('group')['bar'], agg_type)(), df)
lambda df: getattr(df.groupby('group').bar, agg_type)(**kwargs), df)
self._run_test(
lambda df: getattr(df.groupby('group')['foo'], agg_type)(**kwargs), df)
self._run_test(
lambda df: getattr(df.groupby('group')['bar'], agg_type)(**kwargs), df)

@parameterized.expand(ALL_GROUPING_AGGREGATIONS)
def test_groupby_project_dataframe(self, agg_type):
if agg_type == 'describe' and PD_VERSION < (1, 2):
self.skipTest(
"https://github.com/apache/beam/issues/20967: proxy generation of "
"DataFrameGroupBy.describe fails in pandas < 1.2")
kwargs = numeric_only_kwargs_for_pandas_2(agg_type)
self._run_test(
lambda df: getattr(df.groupby('group')[['bar', 'baz']], agg_type)(),
lambda df: getattr(df.groupby('group')[['bar', 'baz']], agg_type)
(**kwargs),
GROUPBY_DF,
check_proxy=False)

Expand Down Expand Up @@ -1760,9 +1796,10 @@ def test_groupby_errors_non_existent_label(self):

def test_groupby_callable(self):
df = GROUPBY_DF

self._run_test(lambda df: df.groupby(lambda x: x % 2).foo.sum(), df)
self._run_test(lambda df: df.groupby(lambda x: x % 5).median(), df)
kwargs = numeric_only_kwargs_for_pandas_2('sum')
self._run_test(lambda df: df.groupby(lambda x: x % 2).foo.sum(**kwargs), df)
kwargs = numeric_only_kwargs_for_pandas_2('median')
self._run_test(lambda df: df.groupby(lambda x: x % 5).median(**kwargs), df)

def test_groupby_apply(self):
df = GROUPBY_DF
Expand Down Expand Up @@ -1817,8 +1854,9 @@ def test_groupby_transform(self):

def test_groupby_pipe(self):
df = GROUPBY_DF

self._run_test(lambda df: df.groupby('group').pipe(lambda x: x.sum()), df)
kwargs = numeric_only_kwargs_for_pandas_2('sum')
self._run_test(
lambda df: df.groupby('group').pipe(lambda x: x.sum(**kwargs)), df)
self._run_test(
lambda df: df.groupby('group')['bool'].pipe(lambda x: x.any()), df)
self._run_test(
Expand Down Expand Up @@ -1900,14 +1938,14 @@ def test_dataframe_groupby_series(self, agg_type):
self.skipTest(
"https://github.com/apache/beam/issues/20967: proxy generation of "
"DataFrameGroupBy.describe fails in pandas < 1.2")

def agg(df, group_by):
kwargs = numeric_only_kwargs_for_pandas_2(agg_type)
return df[df.foo > 40].groupby(group_by).agg(agg_type, **kwargs)

self._run_test(lambda df: agg(df, df.group), GROUPBY_DF, check_proxy=False)
self._run_test(
lambda df: df[df.foo > 40].groupby(df.group).agg(agg_type),
GROUPBY_DF,
check_proxy=False)
self._run_test(
lambda df: df[df.foo > 40].groupby(df.foo % 3).agg(agg_type),
GROUPBY_DF,
check_proxy=False)
lambda df: agg(df, df.foo % 3), GROUPBY_DF, check_proxy=False)

@parameterized.expand(ALL_GROUPING_AGGREGATIONS)
def test_series_groupby_series(self, agg_type):
Expand Down

0 comments on commit 71c68ca

Please sign in to comment.