Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite interp to use apply_ufunc #9881

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
55 changes: 18 additions & 37 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2921,19 +2921,11 @@ def _validate_interp_indexers(
"""Variant of _validate_indexers to be used for interpolation"""
for k, v in self._validate_indexers(indexers):
if isinstance(v, Variable):
if v.ndim == 1:
yield k, v.to_index_variable()
else:
yield k, v
elif isinstance(v, int):
yield k, v
elif is_scalar(v):
yield k, Variable((), v, attrs=self.coords[k].attrs)
elif isinstance(v, np.ndarray):
if v.ndim == 0:
yield k, Variable((), v, attrs=self.coords[k].attrs)
elif v.ndim == 1:
yield k, IndexVariable((k,), v, attrs=self.coords[k].attrs)
else:
raise AssertionError() # Already tested by _validate_indexers
yield k, Variable(dims=(k,), data=v, attrs=self.coords[k].attrs)
else:
raise TypeError(type(v))

Expand Down Expand Up @@ -4127,18 +4119,6 @@ def interp(

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = dict(self._validate_interp_indexers(coords))

if coords:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handled by vectorize=True. This is possibly a perf regression with numpy arrays, but a massive improvement with chunked arrays.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For posterity the bad thing about this approach is that it can greatly expand the number of core dimensions for the problem, limiting the potential for parallelism.

Consider the problem in #6799 (comment). In the following, dimension names are listed out in [].

da[time, q, lat, lon].interp(q=bar[lat,lon]) gets rewritten to da[time,q,lat,lon].interp(q=bar[lat, lon], lat=lat[lat], lon=lon[lon]) which thanks to our automatic rechunking, makes dask merge chunks in lat, lon too, for no benefit.

# This avoids broadcasting over coordinates that are both in
# the original array AND in the indexing array. It essentially
# forces interpolation along the shared coordinates.
sdims = (
set(self.dims)
.intersection(*[set(nx.dims) for nx in indexers.values()])
.difference(coords.keys())
)
indexers.update({d: self.variables[d] for d in sdims})

obj = self if assume_sorted else self.sortby(list(coords))

def maybe_variable(obj, k):
Expand Down Expand Up @@ -4169,16 +4149,18 @@ def _validate_interp_indexer(x, new_x):
for k, v in indexers.items()
}

# optimization: subset to coordinate range of the target index
if method in ["linear", "nearest"]:
for k, v in validated_indexers.items():
obj, newidx = missing._localize(obj, {k: v})
validated_indexers[k] = newidx[k]

# optimization: create dask coordinate arrays once per Dataset
# rather than once per Variable when dask.array.unify_chunks is called later
# GH4739
if obj.__dask_graph__():
has_chunked_array = bool(
any(is_chunked_array(v._data) for v in obj._variables.values())
)
if has_chunked_array:
# optimization: subset to coordinate range of the target index
if method in ["linear", "nearest"]:
for k, v in validated_indexers.items():
obj, newidx = missing._localize(obj, {k: v})
validated_indexers[k] = newidx[k]
# optimization: create dask coordinate arrays once per Dataset
# rather than once per Variable when dask.array.unify_chunks is called later
# GH4739
dask_indexers = {
k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk())
for k, (index, dest) in validated_indexers.items()
Expand All @@ -4190,10 +4172,9 @@ def _validate_interp_indexer(x, new_x):
if name in indexers:
continue

if is_duck_dask_array(var.data):
use_indexers = dask_indexers
else:
use_indexers = validated_indexers
use_indexers = (
dask_indexers if is_duck_dask_array(var.data) else validated_indexers
)

dtype_kind = var.dtype.kind
if dtype_kind in "uifc":
Expand Down
Loading
Loading