Skip to content

Commit

Permalink
Fix groupby binary ops when grouped array is subset relative to other (
Browse files Browse the repository at this point in the history
…#7798)

* Fix groupby binary ops when grouped array is subset relative to other

Closes #7797

* Fix tests

Co-authored-by: Alan Brammer <[email protected]>
Co-authored-by: Mick <[email protected]>

* fix doc build

* [skip-ci] Update doc/whats-new.rst

---------

Co-authored-by: Alan Brammer <[email protected]>
Co-authored-by: Mick <[email protected]>
  • Loading branch information
3 people authored May 2, 2023
1 parent 6d17fa0 commit ca84a1e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
5 changes: 3 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ New Features
~~~~~~~~~~~~
- Added new method :py:meth:`DataArray.to_dask_dataframe` to convert a DataArray into a dask DataFrame (:issue:`7409`).
By `Deeksha <https://github.com/dsgreen2>`_.
- Add support for lshift and rshift binary operators (`<<`, `>>`) on
- Add support for lshift and rshift binary operators (``<<``, ``>>``) on
:py:class:`xr.DataArray` of type :py:class:`int` (:issue:`7727` , :pull:`7741`).
By `Alan Brammer <https://github.com/abrammer>`_.

Expand All @@ -40,7 +40,8 @@ Deprecations

Bug fixes
~~~~~~~~~

- Fix groupby binary ops when the grouped array is a subset relative to the other operand (:issue:`7797`).
By `Deepak Cherian <https://github.com/dcherian>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,9 @@ def _binary_op(self, other, f, reflexive=False):
obj = obj.where(~mask, drop=True)
codes = codes.where(~mask, drop=True).astype(int)

other, _ = align(other, coord, join="outer")
# codes are defined for coord, so we align `other` with `coord`
# before indexing
other, _ = align(other, coord, join="right")
expanded = other.isel({name: codes})

result = g(obj, expanded)
Expand Down
48 changes: 47 additions & 1 deletion xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,11 +829,33 @@ def test_groupby_math_bitshift() -> None:
}
)

left_manual = []
for lev, group in ds.groupby("level"):
shifter = shift.sel(level=lev)
left_manual.append(group << shifter)
left_actual = xr.concat(left_manual, dim="index").reset_coords(names="level")
assert_equal(left_expected, left_actual)

left_actual = (ds.groupby("level") << shift).reset_coords(names="level")
assert_equal(left_expected, left_actual)

right_expected = Dataset(
{
"x": ("index", [0, 0, 2, 2]),
"y": ("index", [-1, -1, -2, -2]),
"level": ("index", [0, 0, 4, 4]),
"index": [0, 1, 2, 3],
}
)
right_manual = []
for lev, group in left_expected.groupby("level"):
shifter = shift.sel(level=lev)
right_manual.append(group >> shifter)
right_actual = xr.concat(right_manual, dim="index").reset_coords(names="level")
assert_equal(right_expected, right_actual)

right_actual = (left_expected.groupby("level") >> shift).reset_coords(names="level")
assert_equal(ds, right_actual)
assert_equal(right_expected, right_actual)


@pytest.mark.parametrize("use_flox", [True, False])
Expand Down Expand Up @@ -1302,8 +1324,15 @@ def test_groupby_math_not_aligned(self):
expected = DataArray([10, 11, np.nan, np.nan], array.coords)
assert_identical(expected, actual)

# regression test for #7797
other = array.groupby("b").sum()
actual = array.sel(x=[0, 1]).groupby("b") - other
expected = DataArray([-1, 0], {"b": ("x", [0, 0]), "x": [0, 1]}, dims="x")
assert_identical(expected, actual)

other = DataArray([10], coords={"c": 123, "b": [0]}, dims="b")
actual = array.groupby("b") + other
expected = DataArray([10, 11, np.nan, np.nan], array.coords)
expected.coords["c"] = (["x"], [123] * 2 + [np.nan] * 2)
assert_identical(expected, actual)

Expand Down Expand Up @@ -2289,3 +2318,20 @@ def test_resample_cumsum(method: str, expected_array: list[float]) -> None:
actual = getattr(ds.foo.resample(time="3M"), method)(dim="time")
expected.coords["time"] = ds.time
assert_identical(expected.drop_vars(["time"]).foo, actual)


def test_groupby_binary_op_regression() -> None:
    """Regression test for #7797: groupby arithmetic on a sliced array.

    A monthly series and a per-month climatology are built from the same
    values, so subtracting the climatology group-wise should return
    "zero anomalies" everywhere -- including when the grouped array only
    holds a slice of the time steps.
    """
    month_starts = xr.date_range("2023-01-01", "2023-12-31", freq="MS")
    values = np.linspace(-1, 1, 12)
    climatology = xr.DataArray(values, coords={"month": np.arange(1, 13, 1)})
    series = xr.DataArray(values, coords={"time": month_starts})

    # The issue was reported for a slice of the series rather than the full
    # series, so only a single time step is kept here.
    april_only = series.sel(time=["2023-04-01"])

    anomalies = april_only.groupby("time.month") - climatology

    assert_identical(xr.zeros_like(anomalies), anomalies)

0 comments on commit ca84a1e

Please sign in to comment.