diff --git a/doc/api.rst b/doc/api.rst index 1050211cabd..fd127c5f867 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -212,6 +212,7 @@ Reshaping and reorganizing Dataset.shift Dataset.roll Dataset.sortby + Dataset.broadcast_like DataArray ========= @@ -386,6 +387,7 @@ Reshaping and reorganizing DataArray.shift DataArray.roll DataArray.sortby + DataArray.broadcast_like .. _api.ufuncs: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a1b3e5416ca..85398996a26 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v0.12.4 (unreleased) New functions/methods ~~~~~~~~~~~~~~~~~~~~~ +- Added :py:meth:`DataArray.broadcast_like` and :py:meth:`Dataset.broadcast_like`. + By `Deepak Cherian `_. + Enhancements ~~~~~~~~~~~~ @@ -48,6 +51,7 @@ New functions/methods (:issue:`3026`). By `Julia Kent `_. + Enhancements ~~~~~~~~~~~~ diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 412560d914a..6d7fcb98a26 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -391,6 +391,58 @@ def reindex_variables( return reindexed, new_indexes +def _get_broadcast_dims_map_common_coords(args, exclude): + + common_coords = OrderedDict() + dims_map = OrderedDict() + for arg in args: + for dim in arg.dims: + if dim not in common_coords and dim not in exclude: + dims_map[dim] = arg.sizes[dim] + if dim in arg.coords: + common_coords[dim] = arg.coords[dim].variable + + return dims_map, common_coords + + +def _broadcast_helper(arg, exclude, dims_map, common_coords): + + from .dataarray import DataArray + from .dataset import Dataset + + def _set_dims(var): + # Add excluded dims to a copy of dims_map + var_dims_map = dims_map.copy() + for dim in exclude: + with suppress(ValueError): + # ignore dim not in var.dims + var_dims_map[dim] = var.shape[var.dims.index(dim)] + + return var.set_dims(var_dims_map) + + def _broadcast_array(array): + data = _set_dims(array.variable) + coords = OrderedDict(array.coords) + coords.update(common_coords) + return DataArray(data, coords, data.dims, name=array.name, + attrs=array.attrs) + + def _broadcast_dataset(ds): + data_vars = OrderedDict( + (k, _set_dims(ds.variables[k])) + for k in ds.data_vars) + coords = OrderedDict(ds.coords) + coords.update(common_coords) + return Dataset(data_vars, coords, ds.attrs) + + if isinstance(arg, DataArray): + return _broadcast_array(arg) + elif isinstance(arg, Dataset): + return _broadcast_dataset(arg) + else: + raise ValueError('all input must be Dataset or DataArray objects') + + def broadcast(*args, exclude=None): """Explicitly broadcast any number of DataArray or Dataset objects against one another. @@ -463,55 +515,16 @@ def broadcast(*args, exclude=None): a (x, y) int64 1 1 2 2 3 3 b (x, y) int64 5 6 5 6 5 6 """ - from .dataarray import DataArray - from .dataset import Dataset if exclude is None: exclude = set() args = align(*args, join='outer', copy=False, exclude=exclude) - common_coords = OrderedDict() - dims_map = OrderedDict() - for arg in args: - for dim in arg.dims: - if dim not in common_coords and dim not in exclude: - dims_map[dim] = arg.sizes[dim] - if dim in arg.coords: - common_coords[dim] = arg.coords[dim].variable - - def _set_dims(var): - # Add excluded dims to a copy of dims_map - var_dims_map = dims_map.copy() - for dim in exclude: - with suppress(ValueError): - # ignore dim not in var.dims - var_dims_map[dim] = var.shape[var.dims.index(dim)] - - return var.set_dims(var_dims_map) - - def _broadcast_array(array): - data = _set_dims(array.variable) - coords = OrderedDict(array.coords) - coords.update(common_coords) - return DataArray(data, coords, data.dims, name=array.name, - attrs=array.attrs) - - def _broadcast_dataset(ds): - data_vars = OrderedDict( - (k, _set_dims(ds.variables[k])) - for k in ds.data_vars) - coords = OrderedDict(ds.coords) - coords.update(common_coords) - return Dataset(data_vars, coords, ds.attrs) - + dims_map, common_coords = _get_broadcast_dims_map_common_coords( + args, exclude) result = [] for arg in args: - if isinstance(arg, DataArray): - result.append(_broadcast_array(arg)) - elif isinstance(arg, Dataset): - result.append(_broadcast_dataset(arg)) - else: - raise ValueError('all input must be Dataset or DataArray objects') + result.append(_broadcast_helper(arg, exclude, dims_map, common_coords)) return tuple(result) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index beaf148df6b..0a1f3934091 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -15,7 +15,9 @@ utils) from .accessor_dt import DatetimeAccessor from .accessor_str import StringAccessor -from .alignment import align, reindex_like_indexers +from .alignment import (align, _broadcast_helper, + _get_broadcast_dims_map_common_coords, + reindex_like_indexers) from .common import AbstractArray, DataWithCoords from .coordinates import ( DataArrayCoordinates, LevelCoordinatesSource, assert_coordinate_consistent, @@ -994,6 +996,29 @@ def sel_points(self, dim='points', method=None, tolerance=None, dim=dim, method=method, tolerance=tolerance, **indexers) return self._from_temp_dataset(ds) + def broadcast_like(self, + other: Union['DataArray', Dataset], + exclude=None) -> 'DataArray': + """Broadcast this DataArray against another Dataset or DataArray. + This is equivalent to xr.broadcast(other, self)[1] + + Parameters + ---------- + other : Dataset or DataArray + Object against which to broadcast this array. + exclude : sequence of str, optional + Dimensions that must not be broadcasted + """ + + if exclude is None: + exclude = set() + args = align(other, self, join='outer', copy=False, exclude=exclude) + + dims_map, common_coords = _get_broadcast_dims_map_common_coords( + args, exclude) + + return _broadcast_helper(self, exclude, dims_map, common_coords) + def reindex_like(self, other: Union['DataArray', Dataset], method: Optional[str] = None, tolerance=None, copy: bool = True, fill_value=dtypes.NA) -> 'DataArray': diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 94d6ec7651e..40f0c720612 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -17,7 +17,8 @@ from ..coding.cftimeindex import _parse_array_of_cftime_strings from . import (alignment, dtypes, duck_array_ops, formatting, groupby, indexing, ops, pdcompat, resample, rolling, utils) -from .alignment import align +from .alignment import (align, _broadcast_helper, + _get_broadcast_dims_map_common_coords) from .common import (ALL_DIMS, DataWithCoords, ImplementsDatasetReduce, _contains_datetime_like_objects) from .coordinates import (DatasetCoordinates, LevelCoordinatesSource, @@ -2027,6 +2028,30 @@ def sel_points(self, dim='points', method=None, tolerance=None, ) return self.isel_points(dim=dim, **pos_indexers) + def broadcast_like(self, + other: Union['Dataset', 'DataArray'], + exclude=None) -> 'Dataset': + """Broadcast this DataArray against another Dataset or DataArray. + This is equivalent to xr.broadcast(other, self)[1] + + Parameters + ---------- + other : Dataset or DataArray + Object against which to broadcast this array. + exclude : sequence of str, optional + Dimensions that must not be broadcasted + + """ + + if exclude is None: + exclude = set() + args = align(other, self, join='outer', copy=False, exclude=exclude) + + dims_map, common_coords = _get_broadcast_dims_map_common_coords( + args, exclude) + + return _broadcast_helper(self, exclude, dims_map, common_coords) + def reindex_like(self, other, method=None, tolerance=None, copy=True, fill_value=dtypes.NA): """Conform this object onto the indexes of another object, filling in diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 30385779e5f..63317519bc7 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1255,6 +1255,21 @@ def test_coords_non_string(self): expected = DataArray(2, coords={1: 2}, name=1) assert_identical(actual, expected) + def test_broadcast_like(self): + original1 = DataArray(np.random.randn(5), + [('x', range(5))]) + + original2 = DataArray(np.random.randn(6), + [('y', range(6))]) + + expected1, expected2 = broadcast(original1, original2) + + assert_identical(original1.broadcast_like(original2), + expected1.transpose('y', 'x')) + + assert_identical(original2.broadcast_like(original1), + expected2) + def test_reindex_like(self): foo = DataArray(np.random.randn(5, 6), [('x', range(5)), ('y', range(6))]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c4925352bbc..db7c5a35de6 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1560,6 +1560,21 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, assert_identical(mdata.sel(x={'one': 'a', 'two': 1}), mdata.sel(one='a', two=1)) + def test_broadcast_like(self): + original1 = DataArray(np.random.randn(5), + [('x', range(5))], name='a').to_dataset() + + original2 = DataArray(np.random.randn(6), + [('y', range(6))], name='b') + + expected1, expected2 = broadcast(original1, original2) + + assert_identical(original1.broadcast_like(original2), + expected1.transpose('y', 'x')) + + assert_identical(original2.broadcast_like(original1), + expected2) + def test_reindex_like(self): data = create_test_data() data['letters'] = ('dim3', 10 * ['a'])