Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix coordinate attr handling in xr.where(..., keep_attrs=True) #7229

Merged
merged 17 commits into from
Nov 30, 2022
Merged
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Deprecations

Bug fixes
~~~~~~~~~

- Fix handling of coordinate attributes in :py:func:`where`. (:issue:`7220`, :pull:`7229`)
By `Sam Levang <https://github.com/slevang>`_.
- Import ``nc_time_axis`` when needed (:issue:`7275`, :pull:`7276`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Expand Down
26 changes: 21 additions & 5 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,15 +1855,13 @@ def where(cond, x, y, keep_attrs=None):
Dataset.where, DataArray.where :
equivalent methods
"""
from .dataset import Dataset

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
if keep_attrs is True:
# keep the attributes of x, the second parameter, by default to
# be consistent with the `where` method of `DataArray` and `Dataset`
keep_attrs = lambda attrs, context: getattr(x, "attrs", {})

# alignment for three arguments is complicated, so don't support it yet
return apply_ufunc(
result = apply_ufunc(
duck_array_ops.where,
cond,
x,
Expand All @@ -1874,6 +1872,24 @@ def where(cond, x, y, keep_attrs=None):
keep_attrs=keep_attrs,
)

# make sure we have the attrs of x across Dataset, DataArray, and coords
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably a neater way to do this but it seems to work.

if keep_attrs is True:
if isinstance(y, Dataset) and not isinstance(x, Dataset):
# handle special case where x gets promoted to Dataset
result.attrs = {}
if getattr(x, "name", None) in result.data_vars:
result[x.name].attrs = getattr(x, "attrs", {})
else:
# otherwise, fill in global attrs and variable attrs (if they exist)
result.attrs = getattr(x, "attrs", {})
for v in getattr(result, "data_vars", []):
result[v].attrs = getattr(getattr(x, v, None), "attrs", {})
for c in getattr(result, "coords", []):
# always fill coord attrs of x
result[c].attrs = getattr(getattr(x, c, None), "attrs", {})

return result


@overload
def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray:
Expand Down
59 changes: 53 additions & 6 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,16 +1925,63 @@ def test_where() -> None:


def test_where_attrs() -> None:
cond = xr.DataArray([True, False], dims="x", attrs={"attr": "cond"})
x = xr.DataArray([1, 1], dims="x", attrs={"attr": "x"})
y = xr.DataArray([0, 0], dims="x", attrs={"attr": "y"})
cond = xr.DataArray([True, False], coords={"a": [0, 1]}, attrs={"attr": "cond_da"})
cond["a"].attrs = {"attr": "cond_coord"}
x = xr.DataArray([1, 1], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
x["a"].attrs = {"attr": "x_coord"}
y = xr.DataArray([0, 0], coords={"a": [0, 1]}, attrs={"attr": "y_da"})
y["a"].attrs = {"attr": "y_coord"}

# 3 DataArrays, takes attrs from x
actual = xr.where(cond, x, y, keep_attrs=True)
expected = xr.DataArray([1, 0], dims="x", attrs={"attr": "x"})
expected = xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
expected["a"].attrs = {"attr": "x_coord"}
assert_identical(expected, actual)

# ensure keep_attrs can handle scalar values
# x as a scalar, takes no attrs
actual = xr.where(cond, 0, y, keep_attrs=True)
expected = xr.DataArray([0, 0], coords={"a": [0, 1]})
assert_identical(expected, actual)

# y as a scalar, takes attrs from x
actual = xr.where(cond, x, 0, keep_attrs=True)
expected = xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
expected["a"].attrs = {"attr": "x_coord"}
assert_identical(expected, actual)

# x and y as a scalar, takes no attrs
actual = xr.where(cond, 1, 0, keep_attrs=True)
assert actual.attrs == {}
expected = xr.DataArray([1, 0], coords={"a": [0, 1]})
assert_identical(expected, actual)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one seems confusing but I don't have a strong opinion

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it doesn't take the cond coord attrs. I still think this would be a convenient (albeit confusing) feature, because I happen to have a bunch of code like xr.where(x>10, 10, x) and would rather keep the attrs. But this is obviously a better use case for DataArray.where.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# cond and y as a scalar, takes attrs from x
actual = xr.where(True, x, y, keep_attrs=True)
expected = xr.DataArray([1, 1], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
expected["a"].attrs = {"attr": "x_coord"}
assert_identical(expected, actual)

# DataArray and 2 Datasets, takes attrs from x
ds_x = xr.Dataset(data_vars={"x": x}, attrs={"attr": "x_ds"})
ds_y = xr.Dataset(data_vars={"x": y}, attrs={"attr": "y_ds"})
actual = xr.where(cond, ds_x, ds_y, keep_attrs=True)
expected = xr.Dataset(
data_vars={
"x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
},
attrs={"attr": "x_ds"},
)
expected["a"].attrs = {"attr": "x_coord"}
assert_identical(expected, actual)

# 2 DataArrays and 1 Dataset, takes attrs from x
actual = xr.where(cond, x.rename("x"), ds_y, keep_attrs=True)
expected = xr.Dataset(
data_vars={
"x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
},
)
expected["a"].attrs = {"attr": "x_coord"}
assert_identical(expected, actual)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a Dataset test too with Dataset.attrs={"foo": "bar'}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As written [DataArray, Dataset, Dataset] takes the Dataset attrs of y. Not sure how to fix this without going deep into the chain of apply_ufunc calls. I'm starting to think rebuilding all the attrs after apply_ufunc might be easiest to get consistent behavior for now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case should be covered now. Looks like we can pass one DataArray and one Dataset so I've included a test on that, but there are a lot of permutations.



@pytest.mark.parametrize(
Expand Down