diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2ad2a426532..7e76066cb33 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -162,9 +162,10 @@ New Features Enhancements ~~~~~~~~~~~~ - Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp` - For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially - rather than interpolating in multidimensional space. (:issue:`2223`) + We performs independant interpolation sequentially rather than interpolating in + one large multidimensional space. (:issue:`2223`) By `Keisuke Fujii `_. +- :py:meth:`DataArray.interp` now support interpolations over chunked dimensions (:pull:`4155`). By `Alexandre Poux `_. - Major performance improvement for :py:meth:`Dataset.from_dataframe` when the dataframe has a MultiIndex (:pull:`4184`). By `Stephan Hoyer `_. diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 59d4f777c73..a6bed408164 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -544,15 +544,6 @@ def _get_valid_fill_mask(arr, dim, limit): ) <= limit -def _assert_single_chunk(var, axes): - for axis in axes: - if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]: - raise NotImplementedError( - "Chunking along the dimension to be interpolated " - "({}) is not yet supported.".format(axis) - ) - - def _localize(var, indexes_coords): """ Speed up for linear and nearest neighbor method. Only consider a subspace that is needed for the interpolation @@ -617,49 +608,42 @@ def interp(var, indexes_coords, method, **kwargs): if not indexes_coords: return var.copy() - # simple speed up for the local interpolation - if method in ["linear", "nearest"]: - var, indexes_coords = _localize(var, indexes_coords) - # default behavior kwargs["bounds_error"] = kwargs.get("bounds_error", False) - # check if the interpolation can be done in orthogonal manner - if ( - len(indexes_coords) > 1 - and method in ["linear", "nearest"] - and all(dest[1].ndim == 1 for dest in indexes_coords.values()) - and len(set([d[1].dims[0] for d in indexes_coords.values()])) - == len(indexes_coords) - ): - # interpolate sequentially - for dim, dest in indexes_coords.items(): - var = interp(var, {dim: dest}, method, **kwargs) - return var - - # target dimensions - dims = list(indexes_coords) - x, new_x = zip(*[indexes_coords[d] for d in dims]) - destination = broadcast_variables(*new_x) - - # transpose to make the interpolated axis to the last position - broadcast_dims = [d for d in var.dims if d not in dims] - original_dims = broadcast_dims + dims - new_dims = broadcast_dims + list(destination[0].dims) - interped = interp_func( - var.transpose(*original_dims).data, x, destination, method, kwargs - ) + result = var + # decompose the interpolation into a succession of independant interpolation + for indexes_coords in decompose_interp(indexes_coords): + var = result + + # simple speed up for the local interpolation + if method in ["linear", "nearest"]: + var, indexes_coords = _localize(var, indexes_coords) + + # target dimensions + dims = list(indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in dims]) + destination = broadcast_variables(*new_x) + + # transpose to make the interpolated axis to the last position + broadcast_dims = [d for d in var.dims if d not in dims] + original_dims = broadcast_dims + dims + new_dims = broadcast_dims + list(destination[0].dims) + interped = interp_func( + var.transpose(*original_dims).data, x, destination, method, kwargs + ) - result = Variable(new_dims, interped, attrs=var.attrs) + result = Variable(new_dims, interped, attrs=var.attrs) - # dimension of the output array - out_dims = OrderedSet() - for d in var.dims: - if d in dims: - out_dims.update(indexes_coords[d][1].dims) - else: - out_dims.add(d) - return result.transpose(*tuple(out_dims)) + # dimension of the output array + out_dims = OrderedSet() + for d in var.dims: + if d in dims: + out_dims.update(indexes_coords[d][1].dims) + else: + out_dims.add(d) + result = result.transpose(*tuple(out_dims)) + return result def interp_func(var, x, new_x, method, kwargs): @@ -706,21 +690,51 @@ def interp_func(var, x, new_x, method, kwargs): if isinstance(var, dask_array_type): import dask.array as da - _assert_single_chunk(var, range(var.ndim - len(x), var.ndim)) - chunks = var.chunks[: -len(x)] + new_x[0].shape - drop_axis = range(var.ndim - len(x), var.ndim) - new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim) - return da.map_blocks( - _interpnd, + nconst = var.ndim - len(x) + + out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim)) + + # blockwise args format + x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)] + x_arginds = [item for pair in x_arginds for item in pair] + new_x_arginds = [ + [_x, [var.ndim + index for index in range(_x.ndim)]] for _x in new_x + ] + new_x_arginds = [item for pair in new_x_arginds for item in pair] + + args = ( var, - x, - new_x, - func, - kwargs, + range(var.ndim), + *x_arginds, + *new_x_arginds, + ) + + _, rechunked = da.unify_chunks(*args) + + args = tuple([elem for pair in zip(rechunked, args[1::2]) for elem in pair]) + + new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] + + new_axes = { + var.ndim + i: new_x[0].chunks[i] + if new_x[0].chunks is not None + else new_x[0].shape[i] + for i in range(new_x[0].ndim) + } + + # if usefull, re-use localize for each chunk of new_x + localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) + + return da.blockwise( + _dask_aware_interpnd, + out_ind, + *args, + interp_func=func, + interp_kwargs=kwargs, + localize=localize, + concatenate=True, dtype=var.dtype, - chunks=chunks, - new_axis=new_axis, - drop_axis=drop_axis, + new_axes=new_axes, ) return _interpnd(var, x, new_x, func, kwargs) @@ -751,3 +765,67 @@ def _interpnd(var, x, new_x, func, kwargs): # move back the interpolation axes to the last position rslt = rslt.transpose(range(-rslt.ndim + 1, 1)) return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) + + +def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): + """Wrapper for `_interpnd` through `blockwise` + + The first half arrays in `coords` are original coordinates, + the other half are destination coordinates + """ + n_x = len(coords) // 2 + nconst = len(var.shape) - n_x + + # _interpnd expect coords to be Variables + x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] + new_x = [ + Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x) + for _x in coords[n_x:] + ] + + if localize: + # _localize expect var to be a Variable + var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var) + + indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)} + + # simple speed up for the local interpolation + var, indexes_coords = _localize(var, indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in indexes_coords]) + + # put var back as a ndarray + var = var.data + + return _interpnd(var, x, new_x, interp_func, interp_kwargs) + + +def decompose_interp(indexes_coords): + """Decompose the interpolation into a succession of independant interpolation keeping the order""" + + dest_dims = [ + dest[1].dims if dest[1].ndim > 0 else [dim] + for dim, dest in indexes_coords.items() + ] + partial_dest_dims = [] + partial_indexes_coords = {} + for i, index_coords in enumerate(indexes_coords.items()): + partial_indexes_coords.update([index_coords]) + + if i == len(dest_dims) - 1: + break + + partial_dest_dims += [dest_dims[i]] + other_dims = dest_dims[i + 1 :] + + s_partial_dest_dims = {dim for dims in partial_dest_dims for dim in dims} + s_other_dims = {dim for dims in other_dims for dim in dims} + + if not s_partial_dest_dims.intersection(s_other_dims): + # this interpolation is orthogonal to the rest + + yield partial_indexes_coords + + partial_dest_dims = [] + partial_indexes_coords = {} + + yield partial_indexes_coords diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e0da3f1527f..d7e88735fbf 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3147,7 +3147,8 @@ def test_upsample_interpolate_regression_1605(self): @requires_dask @requires_scipy - def test_upsample_interpolate_dask(self): + @pytest.mark.parametrize("chunked_time", [True, False]) + def test_upsample_interpolate_dask(self, chunked_time): from scipy.interpolate import interp1d xs = np.arange(6) @@ -3158,6 +3159,8 @@ def test_upsample_interpolate_dask(self): data = np.tile(z, (6, 3, 1)) array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) chunks = {"x": 2, "y": 1} + if chunked_time: + chunks["time"] = 3 expected_times = times.to_series().resample("1H").asfreq().index # Split the times into equal sub-intervals to simulate the 6 hour @@ -3185,13 +3188,6 @@ def test_upsample_interpolate_dask(self): # done here due to floating point arithmetic assert_allclose(expected, actual, rtol=1e-16) - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension - with raises_regex( - NotImplementedError, "Chunking along the dimension to be interpolated" - ): - array.chunk({"time": 1}).resample(time="1H").interpolate("linear") - def test_align(self): array = DataArray( np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"] diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 7a0dda216e2..17e418c3731 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -1,9 +1,18 @@ +from itertools import combinations, permutations + import numpy as np import pandas as pd import pytest import xarray as xr -from xarray.tests import assert_allclose, assert_equal, requires_cftime, requires_scipy +from xarray.tests import ( + assert_allclose, + assert_equal, + assert_identical, + requires_cftime, + requires_dask, + requires_scipy, +) from ..coding.cftimeindex import _parse_array_of_cftime_strings from . import has_dask, has_scipy @@ -63,12 +72,6 @@ def test_interpolate_1d(method, dim, case): da = get_example_data(case) xdest = np.linspace(0.0, 0.9, 80) - - if dim == "y" and case == 1: - with pytest.raises(NotImplementedError): - actual = da.interp(method=method, **{dim: xdest}) - pytest.skip("interpolation along chunked dimension is " "not yet supported") - actual = da.interp(method=method, **{dim: xdest}) # scipy interpolation for the reference @@ -376,8 +379,6 @@ def test_errors(use_dask): # invalid method with pytest.raises(ValueError): da.interp(x=[2, 0], method="boo") - with pytest.raises(ValueError): - da.interp(x=[2, 0], y=2, method="cubic") with pytest.raises(ValueError): da.interp(y=[2, 0], method="boo") @@ -717,3 +718,120 @@ def test_decompose(method): actual = da.interp(x=x_new, y=y_new, method=method).drop(("x", "y")) expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop(("x", "y")) assert_allclose(actual, expected) + + +@requires_scipy +@requires_dask +@pytest.mark.parametrize( + "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] +) +@pytest.mark.parametrize("chunked", [True, False]) +@pytest.mark.parametrize( + "data_ndim,interp_ndim,nscalar", + [ + (data_ndim, interp_ndim, nscalar) + for data_ndim in range(1, 4) + for interp_ndim in range(1, data_ndim + 1) + for nscalar in range(0, interp_ndim + 1) + ], +) +def test_interpolate_chunk_1d(method, data_ndim, interp_ndim, nscalar, chunked): + """Interpolate nd array with multiple independant indexers + + It should do a series of 1d interpolation + """ + + # 3d non chunked data + x = np.linspace(0, 1, 5) + y = np.linspace(2, 4, 7) + z = np.linspace(-0.5, 0.5, 11) + da = xr.DataArray( + data=np.sin(x[:, np.newaxis, np.newaxis]) + * np.cos(y[:, np.newaxis]) + * np.exp(z), + coords=[("x", x), ("y", y), ("z", z)], + ) + kwargs = {"fill_value": "extrapolate"} + + # choose the data dimensions + for data_dims in permutations(da.dims, data_ndim): + + # select only data_ndim dim + da = da.isel( # take the middle line + {dim: len(da.coords[dim]) // 2 for dim in da.dims if dim not in data_dims} + ) + + # chunk data + da = da.chunk(chunks={dim: i + 1 for i, dim in enumerate(da.dims)}) + + # choose the interpolation dimensions + for interp_dims in permutations(da.dims, interp_ndim): + # choose the scalar interpolation dimensions + for scalar_dims in combinations(interp_dims, nscalar): + dest = {} + for dim in interp_dims: + if dim in scalar_dims: + # take the middle point + dest[dim] = 0.5 * (da.coords[dim][0] + da.coords[dim][-1]) + else: + # pick some points, including outside the domain + before = 2 * da.coords[dim][0] - da.coords[dim][1] + after = 2 * da.coords[dim][-1] - da.coords[dim][-2] + + dest[dim] = np.linspace(before, after, len(da.coords[dim]) * 13) + if chunked: + dest[dim] = xr.DataArray(data=dest[dim], dims=[dim]) + dest[dim] = dest[dim].chunk(2) + actual = da.interp(method=method, **dest, kwargs=kwargs) + expected = da.compute().interp(method=method, **dest, kwargs=kwargs) + + assert_identical(actual, expected) + + # all the combinations are usualy not necessary + break + break + break + + +@requires_scipy +@requires_dask +@pytest.mark.parametrize("method", ["linear", "nearest"]) +def test_interpolate_chunk_advanced(method): + """Interpolate nd array with an nd indexer sharing coordinates.""" + # Create original array + x = np.linspace(-1, 1, 5) + y = np.linspace(-1, 1, 7) + z = np.linspace(-1, 1, 11) + t = np.linspace(0, 1, 13) + q = np.linspace(0, 1, 17) + da = xr.DataArray( + data=np.sin(x[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]) + * np.cos(y[:, np.newaxis, np.newaxis, np.newaxis]) + * np.exp(z[:, np.newaxis, np.newaxis]) + * t[:, np.newaxis] + + q, + dims=("x", "y", "z", "t", "q"), + coords={"x": x, "y": y, "z": z, "t": t, "q": q, "label": "dummy_attr"}, + ) + + # Create indexer into `da` with shared coordinate ("full-twist" Möbius strip) + theta = np.linspace(0, 2 * np.pi, 5) + w = np.linspace(-0.25, 0.25, 7) + r = xr.DataArray( + data=1 + w[:, np.newaxis] * np.cos(theta), coords=[("w", w), ("theta", theta)], + ) + + x = r * np.cos(theta) + y = r * np.sin(theta) + z = xr.DataArray( + data=w[:, np.newaxis] * np.sin(theta), coords=[("w", w), ("theta", theta)], + ) + + kwargs = {"fill_value": None} + expected = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method) + + da = da.chunk(2) + x = x.chunk(1) + z = z.chunk(3) + actual = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method) + assert_identical(actual, expected)