Skip to content

Commit

Permalink
Allowed kwargs to pass through to Cython func
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Jan 31, 2018
1 parent dfd1549 commit 01468d1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pandas/_libs/groupby_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
def group_rank_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
ndarray[{{c_type}}, ndim=2] values,
ndarray[int64_t] labels,
bint is_datetimelike):
bint is_datetimelike, **kwargs):
"""
Only transforms on axis=0
"""
Expand Down
25 changes: 15 additions & 10 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,15 +982,15 @@ def _transform_should_cast(self, func_nm):
return (self.size().fillna(0) > 0).any() and (func_nm not in
_cython_cast_blacklist)

def _cython_transform(self, how, numeric_only=True):
def _cython_transform(self, how, numeric_only=True, **kwargs):
output = collections.OrderedDict()
for name, obj in self._iterate_slices():
is_numeric = is_numeric_dtype(obj.dtype)
if numeric_only and not is_numeric:
continue

try:
result, names = self.grouper.transform(obj.values, how)
result, names = self.grouper.transform(obj.values, how, **kwargs)
except NotImplementedError:
continue
except AssertionError as e:
Expand Down Expand Up @@ -1758,9 +1758,12 @@ def cumcount(self, ascending=True):

@Substitution(name='groupby')
@Appender(_doc_template)
def rank(self, axis=0, *args, **kwargs):
def rank(self, ties_method='average', ascending=True, na_option='keep',
pct=False, axis=0):
"""Rank within each group"""
return self._cython_transform('rank', **kwargs)
return self._cython_transform('rank', ties_method=ties_method,
ascending=ascending, na_option=na_option,
pct=pct, axis=axis)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -2237,7 +2240,8 @@ def wrapper(*args, **kwargs):
(how, dtype_str))
return func, dtype_str

def _cython_operation(self, kind, values, how, axis, min_count=-1):
def _cython_operation(self, kind, values, how, axis, min_count=-1,
**kwargs):
assert kind in ['transform', 'aggregate']

# can we do this operation with our cython functions
Expand Down Expand Up @@ -2329,7 +2333,8 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1):

# TODO: min_count
result = self._transform(
result, values, labels, func, is_numeric, is_datetimelike)
result, values, labels, func, is_numeric, is_datetimelike,
**kwargs)

if is_integer_dtype(result):
mask = result == iNaT
Expand Down Expand Up @@ -2368,8 +2373,8 @@ def aggregate(self, values, how, axis=0, min_count=-1):
return self._cython_operation('aggregate', values, how, axis,
min_count=min_count)

def transform(self, values, how, axis=0):
return self._cython_operation('transform', values, how, axis)
def transform(self, values, how, axis=0, **kwargs):
return self._cython_operation('transform', values, how, axis, **kwargs)

def _aggregate(self, result, counts, values, comp_ids, agg_func,
is_numeric, is_datetimelike, min_count=-1):
Expand All @@ -2389,7 +2394,7 @@ def _aggregate(self, result, counts, values, comp_ids, agg_func,
return result

def _transform(self, result, values, comp_ids, transform_func,
is_numeric, is_datetimelike):
is_numeric, is_datetimelike, **kwargs):

comp_ids, _, ngroups = self.group_info
if values.ndim > 3:
Expand All @@ -2403,7 +2408,7 @@ def _transform(self, result, values, comp_ids, transform_func,
transform_func(result[:, :, i], values,
comp_ids, is_datetimelike)
else:
transform_func(result, values, comp_ids, is_datetimelike)
transform_func(result, values, comp_ids, is_datetimelike, **kwargs)

return result

Expand Down

0 comments on commit 01468d1

Please sign in to comment.