Skip to content

Commit

Permalink
implement _BaseGrouper (pandas-dev#29520)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and proost committed Dec 19, 2019
1 parent 8462998 commit a59ee93
Showing 1 changed file with 21 additions and 34 deletions.
55 changes: 21 additions & 34 deletions pandas/_libs/reduction.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,26 @@ cdef class Reducer:
return result


cdef class SeriesBinGrouper:
cdef class _BaseGrouper:
    """
    Common base class for SeriesBinGrouper and SeriesGrouper, providing the
    dummy-validation step both of them use.
    """
    cdef _check_dummy(self, dummy):
        """
        Check `dummy` against ``self.arr`` and extract its backing arrays.

        Returns a ``(values, index)`` tuple taken from ``dummy.values`` and
        ``dummy.index.values``; either one is replaced by a copy when it is
        not contiguous.

        Raises ValueError when neither ``dummy.dtype`` nor
        ``dummy.values.dtype`` equals ``self.arr.dtype``.
        """
        # both values and index must be an ndarray!
        vals = dummy.values

        # GH 23683: datetimetz types are equivalent to datetime types here,
        # so a match on either the dummy's dtype or its values' dtype is
        # accepted.
        if dummy.dtype != self.arr.dtype:
            if vals.dtype != self.arr.dtype:
                raise ValueError('Dummy array must be same dtype')

        # Only inspect `flags` on real arrays:
        # e.g. Categorical has no `flags` attribute
        if util.is_array(vals):
            if not vals.flags.contiguous:
                vals = vals.copy()

        idx = dummy.index.values
        if not idx.flags.contiguous:
            idx = idx.copy()

        return vals, idx


cdef class SeriesBinGrouper(_BaseGrouper):
"""
Performs grouping operation according to bin edges, rather than labels
"""
Expand Down Expand Up @@ -216,21 +235,6 @@ cdef class SeriesBinGrouper:
else:
self.ngroups = len(bins) + 1

cdef _check_dummy(self, dummy):
    """
    Validate `dummy` against ``self.arr`` and return its backing arrays.

    Returns a ``(values, index)`` tuple built from ``dummy.values`` and
    ``dummy.index.values``, each copied if it is not contiguous.

    Raises ValueError if neither ``dummy.dtype`` nor ``dummy.values.dtype``
    matches ``self.arr.dtype``.
    """
    # both values and index must be an ndarray!

    values = dummy.values
    # GH 23683: datetimetz types are equivalent to datetime types here.
    # Comparing only `values.dtype` rejects dummies whose own dtype matches
    # `self.arr.dtype`; accept a match on either, as the other
    # `_check_dummy` implementations in this file do.
    if (dummy.dtype != self.arr.dtype
            and values.dtype != self.arr.dtype):
        raise ValueError('Dummy array must be same dtype')
    if util.is_array(values) and not values.flags.contiguous:
        # e.g. Categorical has no `flags` attribute
        values = values.copy()
    index = dummy.index.values
    if not index.flags.contiguous:
        index = index.copy()

    return values, index

def get_result(self):
cdef:
ndarray arr, result
Expand Down Expand Up @@ -304,7 +308,7 @@ cdef class SeriesBinGrouper:
return result, counts


cdef class SeriesGrouper:
cdef class SeriesGrouper(_BaseGrouper):
"""
Performs generic grouping operation while avoiding ndarray construction
overhead
Expand Down Expand Up @@ -340,23 +344,6 @@ cdef class SeriesGrouper:
self.dummy_arr, self.dummy_index = self._check_dummy(dummy)
self.ngroups = ngroups

cdef _check_dummy(self, dummy):
    """
    Verify that `dummy` is dtype-compatible with ``self.arr`` and hand back
    its ``(values, index)`` arrays, copying whichever is not contiguous.

    Raises ValueError when both ``dummy.dtype`` and ``dummy.values.dtype``
    differ from ``self.arr.dtype``.
    """
    # both values and index must be an ndarray!
    dummy_values = dummy.values

    # GH 23683: datetimetz types are equivalent to datetime types here
    if dummy.dtype != self.arr.dtype:
        if dummy_values.dtype != self.arr.dtype:
            raise ValueError('Dummy array must be same dtype')

    # e.g. Categorical has no `flags` attribute, hence the is_array guard
    if util.is_array(dummy_values):
        if not dummy_values.flags.contiguous:
            dummy_values = dummy_values.copy()

    dummy_index = dummy.index.values
    if not dummy_index.flags.contiguous:
        dummy_index = dummy_index.copy()

    return dummy_values, dummy_index

def get_result(self):
cdef:
# Define result to avoid UnboundLocalError
Expand Down

0 comments on commit a59ee93

Please sign in to comment.