Skip to content

Commit

Permalink
Improve interp performance (pydata#7843)
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan authored and dstansby committed Jun 28, 2023
1 parent 030e5c1 commit 4fc9777
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
var.transpose(*original_dims).data, x, destination, method, kwargs
)

result = Variable(new_dims, interped, attrs=var.attrs)
result = Variable(new_dims, interped, attrs=var.attrs, fastpath=True)

# dimension of the output array
out_dims: OrderedSet = OrderedSet()
Expand All @@ -648,7 +648,8 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
out_dims.update(indexes_coords[d][1].dims)
else:
out_dims.add(d)
result = result.transpose(*out_dims)
if len(out_dims) > 1:
result = result.transpose(*out_dims)
return result


Expand Down Expand Up @@ -709,28 +710,24 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
]
new_x_arginds = [item for pair in new_x_arginds for item in pair]

args = (
var,
range(ndim),
*x_arginds,
*new_x_arginds,
)
args = (var, range(ndim), *x_arginds, *new_x_arginds)

_, rechunked = chunkmanager.unify_chunks(*args)

args = tuple(elem for pair in zip(rechunked, args[1::2]) for elem in pair)

new_x = rechunked[1 + (len(rechunked) - 1) // 2 :]

new_x0_chunks = new_x[0].chunks
new_x0_shape = new_x[0].shape
new_x0_chunks_is_not_none = new_x0_chunks is not None
new_axes = {
ndim + i: new_x[0].chunks[i]
if new_x[0].chunks is not None
else new_x[0].shape[i]
ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i]
for i in range(new_x[0].ndim)
}

# if useful, re-use localize for each chunk of new_x
localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None)
localize = (method in ["linear", "nearest"]) and new_x0_chunks_is_not_none

# scipy.interpolate.interp1d always forces to float.
# Use the same check for blockwise as well:
Expand Down

0 comments on commit 4fc9777

Please sign in to comment.