Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rolling keep_attrs & default True #4510

Merged
merged 20 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ New Features
By `Miguel Jimenez <https://github.com/Mikejmnez>`_ and `Wei Ji Leong <https://github.com/weiji14>`_.
- Unary & binary operations follow the ``keep_attrs`` flag (:issue:`3490`, :issue:`4065`, :issue:`3433`, :issue:`3595`, :pull:`4195`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- :py:attr:`DataArray.rolling` and :py:attr:`Dataset.rolling` now also respect the ``keep_attrs``
keyword for their ``DataArray``s (previously only the global attrs were retained). ``keep_attrs``
keewis marked this conversation as resolved.
Show resolved Hide resolved
is now set to ``True`` per default for rolling operations (:issue:`4497`).
By `Mathias Hauser <https://github.com/mathause>`_.

Bug fixes
~~~~~~~~~
Expand Down
6 changes: 2 additions & 4 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,8 @@ def rolling(
center : bool or mapping, default: False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
If True (default), the object's attributes (`attrs`) will be copied
from the original object to the new one. If False, the new
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If True (default), the object's attributes (`attrs`) will be copied
from the original object to the new one. If False, the new
If None (default) or True, the object's attributes (`attrs`) will be
copied from the original object to the new one. If False, the new

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd replace the (default) and the , optional in the type spec with , default: None

object will be returned without attributes.
**window_kwargs : optional
The keyword arguments form of ``dim``.
Expand Down Expand Up @@ -863,8 +863,6 @@ def rolling(
core.rolling.DataArrayRolling
core.rolling.DatasetRolling
"""
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
return self._rolling_cls(
Expand Down
57 changes: 39 additions & 18 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
center : bool, default: False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
If None (default) or True, the object's attributes (`attrs`) will be
copied from the original object to the new one. If False, the new
object will be returned without attributes.

Returns
Expand All @@ -89,7 +89,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
self.min_periods = np.prod(self.window) if min_periods is None else min_periods

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
keep_attrs = _get_keep_attrs(default=True)
self.keep_attrs = keep_attrs

def __repr__(self):
Expand Down Expand Up @@ -181,8 +181,8 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
center : bool, default: False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
If None (default) or True, the object's attributes (`attrs`) will be
copied from the original object to the new one. If False, the new
object will be returned without attributes.

Returns
Expand All @@ -196,8 +196,6 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
Dataset.rolling
Dataset.groupby
"""
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
super().__init__(
obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs
)
Expand Down Expand Up @@ -294,8 +292,14 @@ def construct(
window = self.obj.variable.rolling_window(
self.dim, self.window, window_dim, self.center, fill_value=fill_value
)

attrs = self.obj.attrs if self.keep_attrs else {}

result = DataArray(
window, dims=self.obj.dims + tuple(window_dim), coords=self.obj.coords
window,
dims=self.obj.dims + tuple(window_dim),
coords=self.obj.coords,
attrs=attrs,
)
return result.isel(
**{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
Expand Down Expand Up @@ -354,14 +358,16 @@ def reduce(self, func, **kwargs):
for d in self.dim
}
windows = self.construct(rolling_dim)
result = windows.reduce(func, dim=list(rolling_dim.values()), **kwargs)
result = windows.reduce(
func, dim=list(rolling_dim.values()), keep_attrs=self.keep_attrs, **kwargs
mathause marked this conversation as resolved.
Show resolved Hide resolved
)

# Find valid windows based on count.
counts = self._counts()
return result.where(counts >= self.min_periods)

def _counts(self):
""" Number of non-nan entries in each rolling window. """
"""Number of non-nan entries in each rolling window."""

rolling_dim = {
d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}")
Expand Down Expand Up @@ -422,7 +428,10 @@ def _bottleneck_reduce(self, func, **kwargs):

if self.center[0]:
values = values[valid]
result = DataArray(values, self.obj.coords)

attrs = self.obj.attrs if self.keep_attrs else {}

result = DataArray(values, self.obj.coords, attrs=attrs)

return result

Expand Down Expand Up @@ -473,8 +482,8 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
center : bool or mapping of hashable to bool, default: False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
If None (default) or True, the object's attributes (`attrs`) will be
copied from the original object to the new one. If False, the new
object will be returned without attributes.

Returns
Expand Down Expand Up @@ -587,6 +596,15 @@ def construct(

from .dataset import Dataset

if keep_attrs is not None:
warnings.warn(
"Passing 'keep_attrs' to 'construct' is deprecated."
" Please pass 'keep_attrs' directly to ``rolling``",
FutureWarning,
)
else:
keep_attrs = self.keep_attrs

if window_dim is None:
if len(window_dim_kwargs) == 0:
raise ValueError(
Expand All @@ -599,12 +617,9 @@ def construct(
)
stride = self._mapping_to_list(stride, default=1)

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

dataset = {}
for key, da in self.obj.data_vars.items():
# keeps rollings only for the dataset depending on slf.dim
# keeps rollings only for the dataset depending on self.dim
dims = [d for d in self.dim if d in da.dims]
if len(dims) > 0:
wi = {d: window_dim[i] for i, d in enumerate(self.dim) if d in da.dims}
Expand All @@ -614,7 +629,13 @@ def construct(
)
else:
dataset[key] = da
return Dataset(dataset, coords=self.obj.coords).isel(

if not keep_attrs:
dataset[key].attrs = {}

attrs = self.obj.attrs if keep_attrs else {}

return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel(
**{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
)

Expand Down
36 changes: 36 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6295,6 +6295,7 @@ def test_rolling_properties(da):
# catching invalid args
with pytest.raises(ValueError, match="window must be > 0"):
da.rolling(time=-2)

with pytest.raises(ValueError, match="min_periods must be greater than zero"):
da.rolling(time=2, min_periods=0)

Expand Down Expand Up @@ -6538,6 +6539,41 @@ def test_ndrolling_construct(center, fill_value):
assert_allclose(actual, expected)


@pytest.mark.parametrize(
"funcname, argument",
[("reduce", (np.mean,)), ("mean", ()), ("construct", ("window_dim",))],
)
def test_rolling_keep_attrs(funcname, argument):

attrs_da = {"da_attr": "test"}

data = np.linspace(10, 15, 100)
coords = np.linspace(1, 10, 100)

ds = DataArray(
data,
dims=("coord"),
coords={"coord": coords},
attrs=attrs_da,
)

# attrs are now kept per default
func = getattr(ds.rolling(dim={"coord": 5}), funcname)
dat = func(*argument)
assert dat.attrs == attrs_da

# discard attrs
func = getattr(ds.rolling(dim={"coord": 5}, keep_attrs=False), funcname)
dat = func(*argument)
assert dat.attrs == {}

# test discard attrs using global option
with set_options(keep_attrs=False):
func = getattr(ds.rolling(dim={"coord": 5}), funcname)
dat = func(*argument)
assert dat.attrs == {}


def test_raise_no_warning_for_nan_in_binary_ops():
with pytest.warns(None) as record:
xr.DataArray([1, 2, np.NaN]) > 0
Expand Down
70 changes: 53 additions & 17 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5995,33 +5995,69 @@ def test_coarsen_keep_attrs():
xr.testing.assert_identical(ds, ds2)


def test_rolling_keep_attrs():
_attrs = {"units": "test", "long_name": "testing"}
@pytest.mark.parametrize(
"funcname, argument",
[("reduce", (np.mean,)), ("mean", ()), ("construct", ("window_dim",))],
)
def test_rolling_keep_attrs(funcname, argument):
global_attrs = {"units": "test", "long_name": "testing"}
attrs_da = {"da_attr": "test"}

var1 = np.linspace(10, 15, 100)
var2 = np.linspace(5, 10, 100)
data = np.linspace(10, 15, 100)
coords = np.linspace(1, 10, 100)

ds = Dataset(
data_vars={"var1": ("coord", var1), "var2": ("coord", var2)},
data_vars={"da": ("coord", data)},
coords={"coord": coords},
attrs=_attrs,
attrs=global_attrs,
)
ds.da.attrs = attrs_da

# Test dropped attrs
dat = ds.rolling(dim={"coord": 5}, min_periods=None, center=False).mean()
# attrs are now kept per default
func = getattr(ds.rolling(dim={"coord": 5}), funcname)
dat = func(*argument)
assert dat.attrs == global_attrs
assert dat.da.attrs == attrs_da

# discard attrs
func = getattr(ds.rolling(dim={"coord": 5}, keep_attrs=False), funcname)
dat = func(*argument)
assert dat.attrs == {}
assert dat.da.attrs == {}

# Test kept attrs using dataset keyword
dat = ds.rolling(
dim={"coord": 5}, min_periods=None, center=False, keep_attrs=True
).mean()
assert dat.attrs == _attrs
# test discard attrs using global option
with set_options(keep_attrs=False):
func = getattr(ds.rolling(dim={"coord": 5}), funcname)
dat = func(*argument)
assert dat.attrs == {}
assert dat.da.attrs == {}

# Test kept attrs using global option
with set_options(keep_attrs=True):
dat = ds.rolling(dim={"coord": 5}, min_periods=None, center=False).mean()
assert dat.attrs == _attrs

def test_rolling_keep_attrs_construct_deprecated():
global_attrs = {"units": "test", "long_name": "testing"}
attrs_da = {"da_attr": "test"}

data = np.linspace(10, 15, 100)
coords = np.linspace(1, 10, 100)

ds = Dataset(
data_vars={"da": ("coord", data)},
coords={"coord": coords},
attrs=global_attrs,
)
ds.da.attrs = attrs_da

# deprecated option
with pytest.warns(
FutureWarning, match="Passing 'keep_attrs' to 'construct' is deprecated"
):
dat = ds.rolling(dim={"coord": 5}, keep_attrs=True).construct(
"window_dim", keep_attrs=False
)

# takes precedence over 'keep_attrs' passed to rolling
assert dat.attrs == {}
assert dat.da.attrs == {}


def test_rolling_properties(ds):
Expand Down