Skip to content

Commit

Permalink
Rework broadcasting so that arguments to _wrapper are named
Browse files Browse the repository at this point in the history
  • Loading branch information
mgunyho committed May 27, 2023
1 parent d081ee6 commit 986fa1f
Showing 1 changed file with 41 additions and 32 deletions.
73 changes: 41 additions & 32 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8689,7 +8689,6 @@ def curvefit(
"""
from scipy.optimize import curve_fit

from xarray.core.alignment import broadcast
from xarray.core.computation import apply_ufunc
from xarray.core.dataarray import _THIS_ARRAY, DataArray

Expand Down Expand Up @@ -8751,32 +8750,37 @@ def curvefit(
f"dimensions {preserved_dims}."
)

# Broadcast all coords with each other
coords_ = broadcast(*coords_)
coords_ = [
coord.broadcast_like(self, exclude=preserved_dims) for coord in coords_
]
n_coords = len(coords_)

params, func_args = _get_func_args(func, param_names)
param_defaults, bounds_defaults = _initialize_curvefit_params(
params, p0, bounds, func_args
)
n_params = len(params)

def _wrapper(Y, *args, **kwargs):
# Wrap curve_fit with raveled coordinates and pointwise NaN handling
# *args contains:
# - the coordinates
# - initial guess
# - lower bounds
# - upper bounds
coords__ = args[:n_coords]
p0_ = args[n_coords + 0 * n_params : n_coords + 1 * n_params]
lb = args[n_coords + 1 * n_params : n_coords + 2 * n_params]
ub = args[n_coords + 2 * n_params :]

x = np.vstack([c.ravel() for c in coords__])

# Convert the coordinates, initial guess, and bounds to DataArrays, so
# that everything is fully broadcast against each other.
# This '+ "_"' is needed for the name because otherwise this would
# create a Dataset with no data variables, since the names of the
# variables would be just the names of the coordinates.
coords_ = Dataset({coord.name + "_": coord for coord in coords_}).to_array(
"coord"
)
param_defaults = Dataset(param_defaults).to_array("param")
bounds_defaults_lo = Dataset(
{name: bound[0] for name, bound in bounds_defaults.items()}
).to_array("param")
bounds_defaults_hi = Dataset(
{name: bound[1] for name, bound in bounds_defaults.items()}
).to_array("param")
n_params = len(param_defaults)

def _wrapper(
Y,
coords: np.ndarray,
p0: np.ndarray,
lb: np.ndarray,
ub: np.ndarray,
**kwargs,
):
x = np.vstack([c.ravel() for c in coords])
y = Y.ravel()
if skipna:
mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0)
Expand All @@ -8787,7 +8791,7 @@ def _wrapper(Y, *args, **kwargs):
pcov = np.full([n_params, n_params], np.nan)
return popt, pcov
x = np.squeeze(x)
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
popt, pcov = curve_fit(func, x, y, p0=p0, bounds=(lb, ub), **kwargs)
return popt, pcov

result = type(self)()
Expand All @@ -8797,18 +8801,23 @@ def _wrapper(Y, *args, **kwargs):
else:
name = f"{str(name)}_"

input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)]
input_core_dims.extend(
[[] for _ in range(3 * n_params)]
) # core_dims for p0 and bounds
input_core_dims = [
reduce_dims_, # input_core_dims for the data
# For the coords. Note that it's important that 'coord' comes
# first, otherwise arrays will be the wrong shape in _wrapper.
["coord"] + reduce_dims_,
["param"], # for the initial guess
["param"], # for upper bounds
["param"], # for lower bounds
]

popt, pcov = apply_ufunc(
_wrapper,
da,
*coords_,
*param_defaults.values(),
*[b[0] for b in bounds_defaults.values()],
*[b[1] for b in bounds_defaults.values()],
coords_,
param_defaults,
bounds_defaults_lo,
bounds_defaults_hi,
vectorize=True,
dask="parallelized",
input_core_dims=input_core_dims,
Expand Down

0 comments on commit 986fa1f

Please sign in to comment.