Skip to content

Commit

Permalink
BUG: Better handling of invalid na_option argument for groupby.rank(#…
Browse files Browse the repository at this point in the history
  • Loading branch information
peterpanmj committed Jul 30, 2018
1 parent d30c4a0 commit ca106c3
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
3 changes: 3 additions & 0 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,6 +1705,9 @@ def rank(self, method='average', ascending=True, na_option='keep',
-----
DataFrame with ranking of values within each group
"""
if na_option not in {'keep', 'top', 'bottom'}:
msg = "na_option must be one of 'keep', 'top', or 'bottom'"
raise ValueError(msg)
return self._cython_transform('rank', numeric_only=False,
ties_method=method, ascending=ascending,
na_option=na_option, pct=pct, axis=axis)
Expand Down
62 changes: 36 additions & 26 deletions pandas/tests/groupby/test_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,35 +172,35 @@ def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp):
[3., 3., np.nan, 1., 3., 2., np.nan, np.nan]),
('dense', False, 'keep', True,
[3. / 3., 3. / 3., np.nan, 1. / 3., 3. / 3., 2. / 3., np.nan, np.nan]),
('average', True, 'no_na', False, [2., 2., 7., 5., 2., 4., 7., 7.]),
('average', True, 'no_na', True,
('average', True, 'bottom', False, [2., 2., 7., 5., 2., 4., 7., 7.]),
('average', True, 'bottom', True,
[0.25, 0.25, 0.875, 0.625, 0.25, 0.5, 0.875, 0.875]),
('average', False, 'no_na', False, [4., 4., 7., 1., 4., 2., 7., 7.]),
('average', False, 'no_na', True,
('average', False, 'bottom', False, [4., 4., 7., 1., 4., 2., 7., 7.]),
('average', False, 'bottom', True,
[0.5, 0.5, 0.875, 0.125, 0.5, 0.25, 0.875, 0.875]),
('min', True, 'no_na', False, [1., 1., 6., 5., 1., 4., 6., 6.]),
('min', True, 'no_na', True,
('min', True, 'bottom', False, [1., 1., 6., 5., 1., 4., 6., 6.]),
('min', True, 'bottom', True,
[0.125, 0.125, 0.75, 0.625, 0.125, 0.5, 0.75, 0.75]),
('min', False, 'no_na', False, [3., 3., 6., 1., 3., 2., 6., 6.]),
('min', False, 'no_na', True,
('min', False, 'bottom', False, [3., 3., 6., 1., 3., 2., 6., 6.]),
('min', False, 'bottom', True,
[0.375, 0.375, 0.75, 0.125, 0.375, 0.25, 0.75, 0.75]),
('max', True, 'no_na', False, [3., 3., 8., 5., 3., 4., 8., 8.]),
('max', True, 'no_na', True,
('max', True, 'bottom', False, [3., 3., 8., 5., 3., 4., 8., 8.]),
('max', True, 'bottom', True,
[0.375, 0.375, 1., 0.625, 0.375, 0.5, 1., 1.]),
('max', False, 'no_na', False, [5., 5., 8., 1., 5., 2., 8., 8.]),
('max', False, 'no_na', True,
('max', False, 'bottom', False, [5., 5., 8., 1., 5., 2., 8., 8.]),
('max', False, 'bottom', True,
[0.625, 0.625, 1., 0.125, 0.625, 0.25, 1., 1.]),
('first', True, 'no_na', False, [1., 2., 6., 5., 3., 4., 7., 8.]),
('first', True, 'no_na', True,
('first', True, 'bottom', False, [1., 2., 6., 5., 3., 4., 7., 8.]),
('first', True, 'bottom', True,
[0.125, 0.25, 0.75, 0.625, 0.375, 0.5, 0.875, 1.]),
('first', False, 'no_na', False, [3., 4., 6., 1., 5., 2., 7., 8.]),
('first', False, 'no_na', True,
('first', False, 'bottom', False, [3., 4., 6., 1., 5., 2., 7., 8.]),
('first', False, 'bottom', True,
[0.375, 0.5, 0.75, 0.125, 0.625, 0.25, 0.875, 1.]),
('dense', True, 'no_na', False, [1., 1., 4., 3., 1., 2., 4., 4.]),
('dense', True, 'no_na', True,
('dense', True, 'bottom', False, [1., 1., 4., 3., 1., 2., 4., 4.]),
('dense', True, 'bottom', True,
[0.25, 0.25, 1., 0.75, 0.25, 0.5, 1., 1.]),
('dense', False, 'no_na', False, [3., 3., 4., 1., 3., 2., 4., 4.]),
('dense', False, 'no_na', True,
('dense', False, 'bottom', False, [3., 3., 4., 1., 3., 2., 4., 4.]),
('dense', False, 'bottom', True,
[0.75, 0.75, 1., 0.25, 0.75, 0.5, 1., 1.])
])
def test_rank_args_missing(grps, vals, ties_method, ascending,
Expand Down Expand Up @@ -252,14 +252,24 @@ def test_rank_object_raises(ties_method, ascending, na_option,
with tm.assert_raises_regex(TypeError, "not callable"):
df.groupby('key').rank(method=ties_method,
ascending=ascending,
na_option='bad', pct=pct)
na_option=na_option, pct=pct)

with tm.assert_raises_regex(TypeError, "not callable"):
df.groupby('key').rank(method=ties_method,
ascending=ascending,
na_option=True, pct=pct)

with tm.assert_raises_regex(TypeError, "not callable"):
@pytest.mark.parametrize("na_option", [True, "bad", 1])
@pytest.mark.parametrize("ties_method", [
'average', 'min', 'max', 'first', 'dense'])
@pytest.mark.parametrize("ascending", [True, False])
@pytest.mark.parametrize("pct", [True, False])
@pytest.mark.parametrize("vals", [
['bar', 'bar', 'foo', 'bar', 'baz'],
['bar', np.nan, 'foo', np.nan, 'baz'],
[1, np.nan, 2, np.nan, 3]
])
def test_rank_naoption_raises(ties_method, ascending, na_option, pct, vals):
df = DataFrame({'key': ['foo'] * 5, 'val': vals})
msg = "na_option must be one of 'keep', 'top', or 'bottom'"

with tm.assert_raises_regex(ValueError, msg):
df.groupby('key').rank(method=ties_method,
ascending=ascending,
na_option=na_option, pct=pct)

0 comments on commit ca106c3

Please sign in to comment.