Skip to content

Commit

Permalink
keep attrs in xarray.where (#4687)
Browse files Browse the repository at this point in the history
Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Illviljan <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>
  • Loading branch information
4 people authored Jan 19, 2022
1 parent d3b6aa6 commit 84961e6
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 6 deletions.
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ New Features
~~~~~~~~~~~~
- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`).
By `Jimmy Westling <https://github.com/illviljan>`_.

- ``keep_attrs`` support for :py:func:`where` (:issue:`4141`, :issue:`4682`, :pull:`4687`).
By `Justus Magin <https://github.com/keewis>`_.
- Enable the limit option for dask array in the following methods :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` (:issue:`6112`)
By `Joseph Nowak <https://github.com/josephnowak>`_.

Expand Down
13 changes: 12 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ def dot(*arrays, dims=None, **kwargs):
return result.transpose(*all_dims, missing_dims="ignore")


def where(cond, x, y):
def where(cond, x, y, keep_attrs=None):
"""Return elements from `x` or `y` depending on `cond`.
Performs xarray-like broadcasting across input arguments.
Expand All @@ -1743,6 +1743,8 @@ def where(cond, x, y):
values to choose from where `cond` is True
y : scalar, array, Variable, DataArray or Dataset
values to choose from where `cond` is False
keep_attrs : bool or str or callable, optional
How to treat attrs. If True, keep the attrs of `x`.
Returns
-------
Expand Down Expand Up @@ -1808,6 +1810,14 @@ def where(cond, x, y):
Dataset.where, DataArray.where :
equivalent methods
"""
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

if keep_attrs is True:
# keep the attributes of x, the second parameter, by default to
# be consistent with the `where` method of `DataArray` and `Dataset`
keep_attrs = lambda attrs, context: attrs[1]

# alignment for three arguments is complicated, so don't support it yet
return apply_ufunc(
duck_array_ops.where,
Expand All @@ -1817,6 +1827,7 @@ def where(cond, x, y):
join="exact",
dataset_join="exact",
dask="allowed",
keep_attrs=keep_attrs,
)


Expand Down
9 changes: 9 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,6 +1922,15 @@ def test_where() -> None:
assert_identical(expected, actual)


def test_where_attrs() -> None:
cond = xr.DataArray([True, False], dims="x", attrs={"attr": "cond"})
x = xr.DataArray([1, 1], dims="x", attrs={"attr": "x"})
y = xr.DataArray([0, 0], dims="x", attrs={"attr": "y"})
actual = xr.where(cond, x, y, keep_attrs=True)
expected = xr.DataArray([1, 0], dims="x", attrs={"attr": "x"})
assert_identical(expected, actual)


@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("use_datetime", [True, False])
def test_polyval(use_dask, use_datetime) -> None:
Expand Down
5 changes: 1 addition & 4 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,10 +2429,7 @@ def test_binary_operations(self, func, dtype):
(
pytest.param(operator.lt, id="less_than"),
pytest.param(operator.ge, id="greater_equal"),
pytest.param(
operator.eq,
id="equal",
),
pytest.param(operator.eq, id="equal"),
),
)
@pytest.mark.parametrize(
Expand Down

0 comments on commit 84961e6

Please sign in to comment.