diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e4c3a4d533f..8684e9c60a7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -77,6 +77,10 @@ Bug fixes By `Tom Nicholas <https://github.com/TomNicholas>`_. - Fix ``RasterioDeprecationWarning`` when using a ``vrt`` in ``open_rasterio``. (:issue:`3964`) By `Taher Chegini <https://github.com/cheginit>`_. +- Fix bug causing :py:meth:`DataArray.interpolate_na` to always drop attributes, + and add a ``keep_attrs`` argument. (:issue:`3968`) + By `Tom Nicholas <https://github.com/TomNicholas>`_. + Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 63cba53b689..9d45dd15887 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2098,6 +2098,7 @@ def interpolate_na( max_gap: Union[ int, float, str, pd.Timedelta, np.timedelta64, datetime.timedelta ] = None, + keep_attrs: bool = None, **kwargs: Any, ) -> "DataArray": """Fill in NaNs by interpolating according to different methods. @@ -2152,6 +2153,10 @@ def interpolate_na( * x (x) int64 0 1 2 3 4 5 6 7 8 The gap lengths are 3-0 = 3; 6-3 = 3; and 8-6 = 2 respectively + keep_attrs : bool, default True + If True, the dataarray's attributes (`attrs`) will be copied from + the original object to the new one. If False, the new + object will be returned without attributes. 
kwargs : dict, optional parameters passed verbatim to the underlying interpolation function @@ -2174,6 +2179,7 @@ def interpolate_na( limit=limit, use_coordinate=use_coordinate, max_gap=max_gap, + keep_attrs=keep_attrs, **kwargs, ) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 40f010b3514..f973b4a5468 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -11,6 +11,7 @@ from .common import _contains_datetime_like_objects, ones_like from .computation import apply_ufunc from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric +from .options import _get_keep_attrs from .utils import OrderedSet, is_scalar from .variable import Variable, broadcast_variables @@ -294,6 +295,7 @@ def interp_na( method: str = "linear", limit: int = None, max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64, dt.timedelta] = None, + keep_attrs: bool = None, **kwargs, ): """Interpolate values according to different methods. @@ -330,19 +332,22 @@ def interp_na( interp_class, kwargs = _get_interpolator(method, **kwargs) interpolator = partial(func_interpolate_na, interp_class, **kwargs) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", "overflow", RuntimeWarning) warnings.filterwarnings("ignore", "invalid value", RuntimeWarning) arr = apply_ufunc( interpolator, - index, self, + index, input_core_dims=[[dim], [dim]], output_core_dims=[[dim]], output_dtypes=[self.dtype], dask="parallelized", vectorize=True, - keep_attrs=True, + keep_attrs=keep_attrs, ).transpose(*self.dims) if limit is not None: @@ -359,8 +364,9 @@ def interp_na( return arr -def func_interpolate_na(interpolator, x, y, **kwargs): +def func_interpolate_na(interpolator, y, x, **kwargs): """helper function to apply interpolation along 1 dimension""" + # reversed arguments are so that attrs are preserved from da, not index # it would be nice if this wasn't necessary, works 
around: # "ValueError: assignment destination is read-only" in assignment below out = y.copy() diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 35c71c2854c..731cd165244 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -231,6 +231,17 @@ def test_interpolate_kwargs(): assert_equal(actual, expected) +def test_interpolate_keep_attrs(): + vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) + mvals = vals.copy() + mvals[2] = np.nan + missing = xr.DataArray(mvals, dims="x") + missing.attrs = {"test": "value"} + + actual = missing.interpolate_na(dim="x", keep_attrs=True) + assert actual.attrs == {"test": "value"} + + def test_interpolate(): vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)