Fix coordinate attr handling in xr.where(..., keep_attrs=True) #7229

Merged on Nov 30, 2022 (17 commits)
3 changes: 2 additions & 1 deletion doc/whats-new.rst
@@ -34,7 +34,8 @@ Deprecations

Bug fixes
~~~~~~~~~

+- Fix handling of coordinate attributes in :py:func:`where`. (:issue:`7220`, :pull:`7229`)
+  By `Sam Levang <https://github.com/slevang>`_.
 - Import ``nc_time_axis`` when needed (:issue:`7275`, :pull:`7276`).
   By `Michael Niklas <https://github.com/headtr1ck>`_.

11 changes: 9 additions & 2 deletions xarray/core/computation.py
@@ -25,7 +25,7 @@
 import numpy as np

 from . import dtypes, duck_array_ops, utils
-from .alignment import align, deep_align
+from .alignment import align, broadcast, deep_align
 from .common import zeros_like
 from .duck_array_ops import datetime_to_numeric
 from .indexes import Index, filter_indexes_from_coords
@@ -1855,12 +1855,19 @@ def where(cond, x, y, keep_attrs=None):
     Dataset.where, DataArray.where :
         equivalent methods
     """
+    from .dataarray import DataArray
+
     if keep_attrs is None:
         keep_attrs = _get_keep_attrs(default=False)
     if keep_attrs is True:
         # keep the attributes of x, the second parameter, by default to
         # be consistent with the `where` method of `DataArray` and `Dataset`
-        keep_attrs = lambda attrs, context: getattr(x, "attrs", {})
+        keep_attrs = lambda attrs, context: attrs[1]
+        # cast non-xarray objects to DataArray to get empty attrs
+        cond, x, y = (v if hasattr(v, "attrs") else DataArray(v) for v in [cond, x, y])
+        # explicitly broadcast to ensure we also get empty coord attrs
+        # take coord attrs preferentially from x, then y, then cond
+        x, y, cond = broadcast(x, y, cond)

     # alignment for three arguments is complicated, so don't support it yet
     return apply_ufunc(
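The change to the `keep_attrs` lambda above can be illustrated with a plain-Python model. This is only a sketch and does not call xarray: `resolve`, `data_attrs`, and `coord_attrs` are hypothetical names, and it assumes the callable is invoked once per output variable with one attrs dict per input, in argument order (cond, x, y). Under that assumption, the old lambda returns the attrs of `x` itself for every variable, coordinates included, while the new one picks the second entry of whatever list it is handed.

```python
# Plain-Python model of the two lambdas; none of these helpers are xarray APIs.
x_attrs = {"attr": "x_da"}

# Assumed inputs: one attrs dict per argument, ordered (cond, x, y),
# for the data itself and for the shared "x" coordinate.
data_attrs = [{"attr": "cond_da"}, {"attr": "x_da"}, {"attr": "y_da"}]
coord_attrs = [{"attr": "cond_coord"}, {"attr": "x_coord"}, {"attr": "y_coord"}]


def old_keep_attrs(attrs, context=None):
    # Ignores the list it receives and always returns x's top-level attrs.
    return x_attrs


def new_keep_attrs(attrs, context=None):
    # Returns the second entry of the list, i.e. the entry that belongs to x.
    return attrs[1]


def resolve(combine):
    # Hypothetical stand-in for "call the combiner once per output variable".
    return {"data": combine(data_attrs), "coord": combine(coord_attrs)}


old = resolve(old_keep_attrs)
new = resolve(new_keep_attrs)
print(old)
print(new)
```

In this model the old combiner stamps the data attrs onto the coordinate, which is the kind of overwrite the PR title refers to; the new combiner leaves the coordinate with its own attrs.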
38 changes: 32 additions & 6 deletions xarray/tests/test_computation.py
@@ -1925,16 +1925,42 @@ def test_where() -> None:


 def test_where_attrs() -> None:
-    cond = xr.DataArray([True, False], dims="x", attrs={"attr": "cond"})
-    x = xr.DataArray([1, 1], dims="x", attrs={"attr": "x"})
-    y = xr.DataArray([0, 0], dims="x", attrs={"attr": "y"})
+    cond = xr.DataArray([True, False], coords={"x": [0, 1]}, attrs={"attr": "cond_da"})
+    cond["x"].attrs = {"attr": "cond_coord"}
+    x = xr.DataArray([1, 1], coords={"x": [0, 1]}, attrs={"attr": "x_da"})
+    x["x"].attrs = {"attr": "x_coord"}
+    y = xr.DataArray([0, 0], coords={"x": [0, 1]}, attrs={"attr": "y_da"})
+    y["x"].attrs = {"attr": "y_coord"}
+
+    # 3 DataArrays, takes attrs from x
     actual = xr.where(cond, x, y, keep_attrs=True)
-    expected = xr.DataArray([1, 0], dims="x", attrs={"attr": "x"})
+    expected = xr.DataArray([1, 0], coords={"x": [0, 1]}, attrs={"attr": "x_da"})
+    expected["x"].attrs = {"attr": "x_coord"}
     assert_identical(expected, actual)

-    # ensure keep_attrs can handle scalar values
+    # x as a scalar, takes coord attrs only from y
+    actual = xr.where(cond, 0, y, keep_attrs=True)
+    expected = xr.DataArray([0, 0], coords={"x": [0, 1]})
+    expected["x"].attrs = {"attr": "y_coord"}
+    assert_identical(expected, actual)
+
+    # y as a scalar, takes attrs from x
+    actual = xr.where(cond, x, 0, keep_attrs=True)
+    expected = xr.DataArray([1, 0], coords={"x": [0, 1]}, attrs={"attr": "x_da"})
+    expected["x"].attrs = {"attr": "x_coord"}
+    assert_identical(expected, actual)
+
+    # x and y as a scalar, takes coord attrs only from cond
     actual = xr.where(cond, 1, 0, keep_attrs=True)
-    assert actual.attrs == {}
+    expected = xr.DataArray([1, 0], coords={"x": [0, 1]})
+    expected["x"].attrs = {"attr": "cond_coord"}
+    assert_identical(expected, actual)
Contributor:
This one seems confusing but I don't have a strong opinion

Contributor Author:
Now it doesn't take the cond coord attrs. I still think this would be a convenient (albeit confusing) feature, because I happen to have a bunch of code like xr.where(x>10, 10, x) and would rather keep the attrs. But this is obviously a better use case for DataArray.where.
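The contrast drawn in this comment can be sketched as follows. The data and attrs here are made up for illustration; the only calls used are `xr.where` with `keep_attrs` left unset and the `DataArray.where` method, which keeps the values of the calling array where the condition holds and substitutes the second argument elsewhere.

```python
import xarray as xr

# Hypothetical data: a small array with attrs on the data and on its coordinate.
x = xr.DataArray([5, 15], coords={"x": [0, 1]}, attrs={"units": "m"})
x["x"].attrs = {"long_name": "position"}

# Function form from the comment above; keep_attrs is left unset here.
via_function = xr.where(x > 10, 10, x)

# Method form: keep x where the condition holds, otherwise use 10.
# Note the condition is inverted relative to the function form.
via_method = x.where(x <= 10, 10)

print(via_function.values)
print(via_method.values)
print(via_method.attrs)
```

Both forms cap the values at 10, and the method form carries the attrs of `x` through, which is why it is the more natural fit when the array of interest is the fallback value rather than the second argument.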


+    # cond and y as a scalar, takes attrs from x
+    actual = xr.where(True, x, y, keep_attrs=True)
+    expected = xr.DataArray([1, 1], coords={"x": [0, 1]}, attrs={"attr": "x_da"})
+    expected["x"].attrs = {"attr": "x_coord"}
+    assert_identical(expected, actual)
Contributor:
Can you add a Dataset test too, with Dataset.attrs = {"foo": "bar"}?

Contributor Author:
As written [DataArray, Dataset, Dataset] takes the Dataset attrs of y. Not sure how to fix this without going deep into the chain of apply_ufunc calls. I'm starting to think rebuilding all the attrs after apply_ufunc might be easiest to get consistent behavior for now?

Contributor Author:
This case should be covered now. Looks like we can pass one DataArray and one Dataset so I've included a test on that, but there are a lot of permutations.



@pytest.mark.parametrize(