Skip to content

Commit

Permalink
Add suggested changes to make apply_ufunc keep attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty committed Sep 18, 2023
1 parent 8974453 commit d27bff4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
4 changes: 3 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def apply_dict_of_variables_vfunc(
join="inner",
fill_value=None,
on_missing_core_dim: MissingCoreDimOptions = "raise",
keep_attrs="override",
):
"""Apply a variable level function over dicts of DataArray, DataArray,
Variable and ndarray objects.
Expand All @@ -445,7 +446,7 @@ def apply_dict_of_variables_vfunc(
for name, variable_args in zip(names, grouped_by_name):
core_dim_present = _check_core_dims(signature, variable_args, name)
if core_dim_present is True:
result_vars[name] = func(*variable_args)
result_vars[name] = func(*variable_args, keep_attrs=keep_attrs)
else:
if on_missing_core_dim == "raise":
raise ValueError(core_dim_present)
Expand Down Expand Up @@ -522,6 +523,7 @@ def apply_dataset_vfunc(
join=dataset_join,
fill_value=fill_value,
on_missing_core_dim=on_missing_core_dim,
keep_attrs=keep_attrs,
)

out: Dataset | tuple[Dataset, ...]
Expand Down
12 changes: 3 additions & 9 deletions xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,18 +772,13 @@ def test_rolling_exp_keep_attrs(self, ds) -> None:
# discard attrs
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}
# TODO: from #8114 — this arguably should be empty, but `apply_ufunc` doesn't do
# that at the moment. We should change in `apply_func` rather than
# special-case it here.
#
# assert result.z1.attrs == {}
assert result.z1.attrs == {}

# test discard attrs using global option
with set_options(keep_attrs=False):
result = ds.rolling_exp(time=10).mean()
assert result.attrs == {}
# See above
# assert result.z1.attrs == {}
assert result.z1.attrs == {}

# keyword takes precedence over global option
with set_options(keep_attrs=False):
Expand All @@ -794,8 +789,7 @@ def test_rolling_exp_keep_attrs(self, ds) -> None:
with set_options(keep_attrs=True):
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}
# See above
# assert result.z1.attrs == {}
assert result.z1.attrs == {}

with pytest.warns(
UserWarning,
Expand Down

0 comments on commit d27bff4

Please sign in to comment.