Skip to content

Commit

Permalink
Raise error if multiple by's are used with Ellipsis (#149)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] authored Sep 22, 2022
1 parent ab29d2c commit af3e3ce
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
13 changes: 8 additions & 5 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def xarray_reduce(
if skipna is not None and isinstance(func, Aggregation):
raise ValueError("skipna must be None when func is an Aggregation.")

nby = len(by)
for b in by:
if isinstance(b, xr.DataArray) and b.name is None:
raise ValueError("Cannot group by unnamed DataArrays.")
Expand All @@ -203,11 +204,11 @@ def xarray_reduce(
keep_attrs = True

if isinstance(isbin, bool):
isbin = (isbin,) * len(by)
isbin = (isbin,) * nby
if expected_groups is None:
expected_groups = (None,) * len(by)
expected_groups = (None,) * nby
if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list
if len(by) == 1:
if nby == 1:
expected_groups = (expected_groups,)
else:
raise ValueError("Needs better message.")
Expand Down Expand Up @@ -239,6 +240,8 @@ def xarray_reduce(
ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])

if dim is Ellipsis:
if nby > 1:
raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
dim = tuple(obj.dims)
if by[0].name in ds.dims and not isbin[0]:
dim = tuple(d for d in dim if d != by[0].name)
Expand Down Expand Up @@ -351,7 +354,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
missing_dim[k] = v

input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims)
input_core_dims += [input_core_dims[-1]] * (len(by) - 1)
input_core_dims += [input_core_dims[-1]] * (nby - 1)

actual = xr.apply_ufunc(
wrapper,
Expand Down Expand Up @@ -409,7 +412,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
if unindexed_dims:
actual = actual.drop_vars(unindexed_dims)

if len(by) == 1:
if nby == 1:
for var in actual:
if isinstance(obj, xr.DataArray):
template = obj
Expand Down
3 changes: 3 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):
actual = xarray_reduce(da, "labels", "labels2", **kwargs)
xr.testing.assert_identical(expected, actual)

with pytest.raises(NotImplementedError):
xarray_reduce(da, "labels", "labels2", dim=..., **kwargs)


@requires_dask
def test_dask_groupers_error():
Expand Down

0 comments on commit af3e3ce

Please sign in to comment.