From 986fa1f47894e47b32187222312a046bde3d7786 Mon Sep 17 00:00:00 2001 From: mgunyho <20118130+mgunyho@users.noreply.github.com> Date: Sat, 27 May 2023 12:27:04 +0300 Subject: [PATCH] Rework broadcasting so that arguments to _wrapper are named --- xarray/core/dataset.py | 73 ++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 32 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c5d5ae2e2f7..30cfe3f9fa2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8689,7 +8689,6 @@ def curvefit( """ from scipy.optimize import curve_fit - from xarray.core.alignment import broadcast from xarray.core.computation import apply_ufunc from xarray.core.dataarray import _THIS_ARRAY, DataArray @@ -8751,32 +8750,37 @@ def curvefit( f"dimensions {preserved_dims}." ) - # Broadcast all coords with each other - coords_ = broadcast(*coords_) - coords_ = [ - coord.broadcast_like(self, exclude=preserved_dims) for coord in coords_ - ] - n_coords = len(coords_) - params, func_args = _get_func_args(func, param_names) param_defaults, bounds_defaults = _initialize_curvefit_params( params, p0, bounds, func_args ) - n_params = len(params) - - def _wrapper(Y, *args, **kwargs): - # Wrap curve_fit with raveled coordinates and pointwise NaN handling - # *args contains: - # - the coordinates - # - initial guess - # - lower bounds - # - upper bounds - coords__ = args[:n_coords] - p0_ = args[n_coords + 0 * n_params : n_coords + 1 * n_params] - lb = args[n_coords + 1 * n_params : n_coords + 2 * n_params] - ub = args[n_coords + 2 * n_params :] - - x = np.vstack([c.ravel() for c in coords__]) + + # Convert the coordinates, initial guess, and bounds to DataArrays, so + # that everything is fully broadcast against each other. + # This '+ "_"' is needed for the name because otherwise this would + # create a Dataset with no data variables, since the names of the + # variables would be just the names of the coordinates. 
+ coords_ = Dataset({coord.name + "_": coord for coord in coords_}).to_array( + "coord" + ) + param_defaults = Dataset(param_defaults).to_array("param") + bounds_defaults_lo = Dataset( + {name: bound[0] for name, bound in bounds_defaults.items()} + ).to_array("param") + bounds_defaults_hi = Dataset( + {name: bound[1] for name, bound in bounds_defaults.items()} + ).to_array("param") + n_params = len(param_defaults) + + def _wrapper( + Y, + coords: np.ndarray, + p0: np.ndarray, + lb: np.ndarray, + ub: np.ndarray, + **kwargs, + ): + x = np.vstack([c.ravel() for c in coords]) y = Y.ravel() if skipna: mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0) @@ -8787,7 +8791,7 @@ def _wrapper(Y, *args, **kwargs): pcov = np.full([n_params, n_params], np.nan) return popt, pcov x = np.squeeze(x) - popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs) + popt, pcov = curve_fit(func, x, y, p0=p0, bounds=(lb, ub), **kwargs) return popt, pcov result = type(self)() @@ -8797,18 +8801,23 @@ def _wrapper(Y, *args, **kwargs): else: name = f"{str(name)}_" - input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)] - input_core_dims.extend( - [[] for _ in range(3 * n_params)] - ) # core_dims for p0 and bounds + input_core_dims = [ + reduce_dims_, # input_core_dims for the data + # For the coords. Note that it's important that 'coord' comes + # first, otherwise arrays will be the wrong shape in _wrapper. + ["coord"] + reduce_dims_, + ["param"], # for the initial guess + ["param"], # for lower bounds + ["param"], # for upper bounds + ] popt, pcov = apply_ufunc( _wrapper, da, - *coords_, - *param_defaults.values(), - *[b[0] for b in bounds_defaults.values()], - *[b[1] for b in bounds_defaults.values()], + coords_, + param_defaults, + bounds_defaults_lo, + bounds_defaults_hi, vectorize=True, dask="parallelized", input_core_dims=input_core_dims,