Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make grouping column an agg #6997

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,9 +673,17 @@ def var(self, ddof=1):
def size(self):
"""
Compute group sizes

"""
return self.grouper.size()

def count(self):
"""
Number of non-null items in each group.

"""
return self._python_agg_general(lambda x: notnull(x).sum())

sum = _groupby_function('sum', 'add', np.sum)
prod = _groupby_function('prod', 'prod', np.prod)
min = _groupby_function('min', 'min', np.min, numeric_only=False)
Expand All @@ -687,12 +695,10 @@ def size(self):

def ohlc(self):
"""
Compute sum of values, excluding missing values

For multiple groupings, the result index will be a MultiIndex
Deprecated, use .resample(how="ohlc") instead.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we deprecate it first? (and why is it deprecated? Is there an issue about it?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't work atm, it raises with strangeish message... see #6594.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, that is a good reason :-)


"""
return self._cython_agg_general('ohlc')
raise AttributeError('ohlc is deprecated, use resample(how="ohlc").')

def nth(self, n, dropna=None):
"""
Expand Down Expand Up @@ -939,6 +945,7 @@ def _cython_agg_general(self, how, numeric_only=True):
result, names = self.grouper.aggregate(obj.values, how)
except AssertionError as e:
raise GroupByError(str(e))
# infer old dytpe
output[name] = self._try_cast(result, obj)

if len(output) == 0:
Expand All @@ -947,6 +954,8 @@ def _cython_agg_general(self, how, numeric_only=True):
return self._wrap_aggregated_output(output, names)

def _python_agg_general(self, func, *args, **kwargs):
_dtype = kwargs.pop("_dtype", None)

func = _intercept_function(func)
f = lambda x: func(x, *args, **kwargs)

Expand All @@ -955,7 +964,14 @@ def _python_agg_general(self, func, *args, **kwargs):
for name, obj in self._iterate_slices():
try:
result, counts = self.grouper.agg_series(obj, f)
output[name] = self._try_cast(result, obj)

if _dtype is None: # infer old dytpe
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is all this necessary? _try_cast does this IIRC

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like agg must upcast to float, i can't recall tbh... also my choice of api is not good.

I'm not even using it, so er... guess I should remove.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, it upcast to float... this explains the "wip" commit message

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok but try_cast should still be able to fix it maybe it needs more hints)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can take this - u can lot on plate!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would be amazing, thanks! This was also attempting to address #6594 btw, looks like ohlc in there twice.

output[name] = self._try_cast(result, obj)
elif _dtype is False:
output[name] = result
else:
output[name] = _possibly_downcast_to_dtype(result, _dtype)

except TypeError:
continue

Expand Down Expand Up @@ -2889,16 +2905,6 @@ def _apply_to_column_groupbys(self, func):
in self._iterate_column_groupbys()),
keys=self._selected_obj.columns, axis=1)

def ohlc(self):
"""
Compute sum of values, excluding missing values

For multiple groupings, the result index will be a MultiIndex
"""
return self._apply_to_column_groupbys(
lambda x: x._cython_agg_general('ohlc'))


from pandas.tools.plotting import boxplot_frame_groupby
DataFrameGroupBy.boxplot = boxplot_frame_groupby

Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,18 @@ def test_size(self):
for key, group in grouped:
self.assertEquals(result[key], len(group))

def test_count(self):
df = pd.DataFrame([[1, 2], [1, nan], [3, nan]], columns=['A', 'B'])
count_as = df.groupby('A').count()
count_not_as = df.groupby('A', as_index=False).count()

res = pd.DataFrame([[1, 1], [3, 0]], columns=['A', 'B'])
assert_frame_equal(count_not_as, res)
assert_frame_equal(count_as, res.set_index('A'))

count_B = df.groupby('A')['B'].count()
assert_series_equal(count_B, res['B'])

def test_grouping_ndarray(self):
grouped = self.df.groupby(self.df['A'].values)

Expand Down