Skip to content

Commit

Permalink
consolidate dim checks (pandas-dev#29536)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and proost committed Dec 19, 2019
1 parent d56200d commit 007b9d4
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def get_group_levels(self):

_cython_arity = {"ohlc": 4} # OHLC

_name_functions = {"ohlc": lambda *args: ["open", "high", "low", "close"]}
_name_functions = {"ohlc": ["open", "high", "low", "close"]}

def _is_builtin_func(self, arg):
"""
Expand Down Expand Up @@ -433,6 +433,13 @@ def _cython_operation(
assert kind in ["transform", "aggregate"]
orig_values = values

if values.ndim > 2:
raise NotImplementedError("number of dimensions is currently limited to 2")
elif values.ndim == 2:
# Note: it is *not* the case that axis is always 0 for 1-dim values,
# as we can have 1D ExtensionArrays that we need to treat as 2D
assert axis == 1, axis

# can we do this operation with our cython functions
# if not raise NotImplementedError

Expand Down Expand Up @@ -545,10 +552,7 @@ def _cython_operation(
if vdim == 1 and arity == 1:
result = result[:, 0]

if how in self._name_functions:
names = self._name_functions[how]() # type: Optional[List[str]]
else:
names = None
names = self._name_functions.get(how, None) # type: Optional[List[str]]

if swapped:
result = result.swapaxes(0, axis)
Expand Down Expand Up @@ -578,10 +582,7 @@ def _aggregate(
is_datetimelike: bool,
min_count: int = -1,
):
if values.ndim > 2:
# punting for now
raise NotImplementedError("number of dimensions is currently limited to 2")
elif agg_func is libgroupby.group_nth:
if agg_func is libgroupby.group_nth:
# different signature from the others
# TODO: should we be using min_count instead of hard-coding it?
agg_func(result, counts, values, comp_ids, rank=1, min_count=-1)
Expand All @@ -595,11 +596,7 @@ def _transform(
):

comp_ids, _, ngroups = self.group_info
if values.ndim > 2:
# punting for now
raise NotImplementedError("number of dimensions is currently limited to 2")
else:
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)

return result

Expand Down

0 comments on commit 007b9d4

Please sign in to comment.