Clear up broadcasting
dcherian committed Dec 14, 2024
1 parent be5c783 commit a5e1854
Showing 1 changed file with 38 additions and 24 deletions.
62 changes: 38 additions & 24 deletions xarray/core/missing.py
@@ -3,6 +3,7 @@
import datetime as dt
import itertools
import warnings
from collections import ChainMap
from collections.abc import Callable, Generator, Hashable, Sequence
from functools import partial
from numbers import Number
@@ -710,59 +711,66 @@ def interpolate_variable(
func, kwargs = _get_interpolator_nd(method, **kwargs)

in_coords, result_coords = zip(*(v for v in indexes_coords.values()), strict=True)
# broadcast out manually to minimize confusing behaviour
broadcast_result_coords = broadcast_variables(*result_coords)
result_dims = broadcast_result_coords[0].dims

# input coordinates along which we are interpolating are core dimensions
# the corresponding output coordinates may or may not have the same name,
# so `all_in_core_dims` is also `exclude_dims`
all_in_core_dims = set(indexes_coords)

result_dims = OrderedSet(itertools.chain(*(_.dims for _ in result_coords)))
result_sizes = ChainMap(*(_.sizes for _ in result_coords))
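To illustrate the `ChainMap` usage above (an editor's sketch, not part of this commit; the coordinate names and sizes are hypothetical):

from collections import ChainMap

# per-coordinate `sizes` mappings, e.g. for target coordinates bar[lat, lon] and baz[lon]
sizes_bar = {"lat": 4, "lon": 5}
sizes_baz = {"lon": 5}

# one merged, copy-free view; the first mapping that contains a key wins
result_sizes = ChainMap(sizes_bar, sizes_baz)
print(result_sizes["lat"], result_sizes["lon"])  # 4 5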

# any dimensions on the output that are present on the input, but are not being
# interpolated along are broadcast or loop dimensions along which we automatically
# vectorize. Consider the problem in
# https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
# interpolated along are dimensions along which we automatically vectorize.
# Consider the problem in https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
# In the following, dimension names are listed out in [].
# # da[time, q, lat, lon].interp(q=bar[lat,lon]). Here `lat`, `lon`
# are input dimensions, present on the output, along which we vectorize.
# We track these as "result broadcast dimensions".
# are input dimensions, present on the output, but are not the coordinates
# we are explicitly interpolating. These are the dimensions along which we vectorize.
# `q` is the only input core dimension, and changes size (disappears)
# so it is in exclude_dims.
result_broadcast_dims = set(
itertools.chain(dim for dim in result_dims if dim not in all_in_core_dims)
)
vectorize_dims = (result_dims - all_in_core_dims) & set(var.dims)
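A worked example of the set arithmetic above, for the `da[time, q, lat, lon].interp(q=bar[lat, lon])` case described in the comments (editor's sketch; plain sets stand in for `OrderedSet`):

# dims of the target coordinate bar[lat, lon]
result_dims = {"lat", "lon"}
# the coordinate we interpolate along; it disappears from the output
all_in_core_dims = {"q"}
# dims of the variable being interpolated
var_dims = {"time", "q", "lat", "lon"}

vectorize_dims = (result_dims - all_in_core_dims) & var_dims
print(vectorize_dims)  # {'lat', 'lon'} (order may vary): the dims we vectorize over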

# remove any output broadcast dimensions from the list of core dimensions
output_core_dims = tuple(d for d in result_dims if d not in result_broadcast_dims)
output_core_dims = tuple(d for d in result_dims if d not in vectorize_dims)
input_core_dims = (
# all coordinates on the input that we interpolate along
[tuple(indexes_coords)]
# the input coordinates are always 1D at the moment, so we just need to list out their names
+ [tuple(_.dims) for _ in in_coords]
# The last set of inputs are the coordinates we are interpolating to.
# These have been broadcast already for ease.
+ [output_core_dims] * len(result_coords)
+ [
tuple(d for d in coord.dims if d not in vectorize_dims)
for coord in result_coords
]
)
output_sizes = {k: broadcast_result_coords[0].sizes[k] for k in output_core_dims}
output_sizes = {k: result_sizes[k] for k in output_core_dims}
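Continuing the same hypothetical example, the core-dimension bookkeeping works out as follows (editor's sketch; the empty tuples reflect that `lat`/`lon` are handled by vectorization rather than passed as core dims):

# for da[time, q, lat, lon].interp(q=bar[lat, lon]):
index_dims = ("q",)                     # tuple(indexes_coords)
in_coord_dims = [("q",)]                # the 1D input coordinate along q
vectorize_dims = {"lat", "lon"}
result_dims = ["lat", "lon"]

output_core_dims = tuple(d for d in result_dims if d not in vectorize_dims)
target_core_dims = [tuple(d for d in ("lat", "lon") if d not in vectorize_dims)]
input_core_dims = [index_dims] + in_coord_dims + target_core_dims

print(input_core_dims)   # [('q',), ('q',), ()]
print(output_core_dims)  # ()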

# scipy.interpolate.interp1d always forces to float.
dtype = float if not issubclass(var.dtype.type, np.inexact) else var.dtype
result = apply_ufunc(
_interpnd,
var,
*in_coords,
*broadcast_result_coords,
*result_coords,
input_core_dims=input_core_dims,
output_core_dims=[output_core_dims],
exclude_dims=all_in_core_dims,
dask="parallelized",
kwargs=dict(interp_func=func, interp_kwargs=kwargs),
kwargs=dict(
interp_func=func,
interp_kwargs=kwargs,
# we leave broadcasting up to dask if possible
# but we need broadcasted values in _interpnd, so propagate that
# context (dimension names), and broadcast there
# This would be unnecessary if we could tell apply_ufunc
# to insert size-1 broadcast dimensions
result_coord_core_dims=input_core_dims[-len(result_coords) :],
),
# TODO: deprecate and have the user rechunk themselves
dask_gufunc_kwargs=dict(output_sizes=output_sizes, allow_rechunk=True),
output_dtypes=[dtype],
# if there are any broadcast dims on the result, we must vectorize on them
vectorize=bool(result_broadcast_dims),
vectorize=bool(vectorize_dims),
keep_attrs=True,
)
return result
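For context, a minimal end-to-end call that exercises this code path, modelled on the linked issue; the data, sizes, and values here are made up by the editor:

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.rand(2, 3, 4, 5),
    dims=("time", "q", "lat", "lon"),
    coords={"q": [0.0, 0.5, 1.0]},
)
# target values for q vary over (lat, lon), which therefore become vectorize dims
bar = xr.DataArray(np.random.rand(4, 5), dims=("lat", "lon"))

out = da.interp(q=bar)  # q disappears; lat/lon are looped over, time is broadcast
print(out.dims)         # expected: ('time', 'lat', 'lon')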
@@ -787,7 +795,11 @@ def _interp1d(


def _interpnd(
data: np.ndarray, *coords: np.ndarray, interp_func: InterpCallable, interp_kwargs
data: np.ndarray,
*coords: np.ndarray,
interp_func: InterpCallable,
interp_kwargs,
result_coord_core_dims,
) -> np.ndarray:
"""
Core nD array interpolation routine.
Expand All @@ -801,10 +813,12 @@ def _interpnd(
# Convert everything to Variables, since that makes applying
# `_localize` and `_floatize_x` much easier
x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])]
new_x = [
Variable([f"dim_{ndim + dim}" for dim in range(_x.ndim)], _x)
for _x in coords[n_x:]
]
new_x = broadcast_variables(
*(
Variable(dims, _x)
for dims, _x in zip(result_coord_core_dims, coords[n_x:], strict=True)
)
)
var = Variable([f"dim_{dim}" for dim in range(ndim)], data)

if interp_kwargs.get("method") in ["linear", "nearest"]:
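Finally, a small sketch of what `broadcast_variables` does with the reconstructed target coordinates inside `_interpnd`. This is an editor's illustration using xarray's internal `Variable` (assumed importable from `xarray.core.variable`), not code from this commit:

import numpy as np
from xarray.core.variable import Variable, broadcast_variables

# two target coordinates reconstructed with their core dims, as _interpnd receives them
a = Variable(("lat",), np.linspace(0.0, 1.0, 3))
b = Variable(("lon",), np.linspace(0.0, 1.0, 4))

a2, b2 = broadcast_variables(a, b)
print(a2.dims, b2.dims)  # both ('lat', 'lon'): each is broadcast against the other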
