
Commit

Add broadcast_like. (#3086)
* Add broadcast_like.

Closes #2885

* lint.

* lint2

* Use a helper function

* lint
dcherian authored and shoyer committed Jul 14, 2019
1 parent c4497ff commit 9438390
Showing 7 changed files with 143 additions and 44 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
@@ -212,6 +212,7 @@ Reshaping and reorganizing
   Dataset.shift
   Dataset.roll
   Dataset.sortby
   Dataset.broadcast_like

DataArray
=========
@@ -386,6 +387,7 @@ Reshaping and reorganizing
   DataArray.shift
   DataArray.roll
   DataArray.sortby
   DataArray.broadcast_like

.. _api.ufuncs:

4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -24,6 +24,9 @@ This release increases the minimum required Python version from 3.5.0 to 3.5.3
New functions/methods
~~~~~~~~~~~~~~~~~~~~~

- Added :py:meth:`DataArray.broadcast_like` and :py:meth:`Dataset.broadcast_like`.
  By `Deepak Cherian <https://github.com/dcherian>`_.

Enhancements
~~~~~~~~~~~~

@@ -54,6 +57,7 @@ New functions/methods
  (:issue:`3026`).
  By `Julia Kent <https://github.com/jukent>`_.


Enhancements
~~~~~~~~~~~~

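For orientation while reading the release note, here is a minimal sketch of what the new method does; the dimension names and sizes below are illustrative and not taken from the changelog.

import numpy as np
import xarray as xr

# Two arrays with disjoint dimensions.
a = xr.DataArray(np.random.randn(4), dims='time')
b = xr.DataArray(np.random.randn(3), dims='space')

# broadcast_like expands one object against another's dimensions; per the
# new docstrings it is equivalent to xr.broadcast(b, a)[1].
print(a.broadcast_like(b).dims)  # ('space', 'time')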
97 changes: 55 additions & 42 deletions xarray/core/alignment.py
@@ -391,6 +391,58 @@ def reindex_variables(
    return reindexed, new_indexes


def _get_broadcast_dims_map_common_coords(args, exclude):

    common_coords = OrderedDict()
    dims_map = OrderedDict()
    for arg in args:
        for dim in arg.dims:
            if dim not in common_coords and dim not in exclude:
                dims_map[dim] = arg.sizes[dim]
                if dim in arg.coords:
                    common_coords[dim] = arg.coords[dim].variable

    return dims_map, common_coords


def _broadcast_helper(arg, exclude, dims_map, common_coords):

    from .dataarray import DataArray
    from .dataset import Dataset

    def _set_dims(var):
        # Add excluded dims to a copy of dims_map
        var_dims_map = dims_map.copy()
        for dim in exclude:
            with suppress(ValueError):
                # ignore dim not in var.dims
                var_dims_map[dim] = var.shape[var.dims.index(dim)]

        return var.set_dims(var_dims_map)

    def _broadcast_array(array):
        data = _set_dims(array.variable)
        coords = OrderedDict(array.coords)
        coords.update(common_coords)
        return DataArray(data, coords, data.dims, name=array.name,
                         attrs=array.attrs)

    def _broadcast_dataset(ds):
        data_vars = OrderedDict(
            (k, _set_dims(ds.variables[k]))
            for k in ds.data_vars)
        coords = OrderedDict(ds.coords)
        coords.update(common_coords)
        return Dataset(data_vars, coords, ds.attrs)

    if isinstance(arg, DataArray):
        return _broadcast_array(arg)
    elif isinstance(arg, Dataset):
        return _broadcast_dataset(arg)
    else:
        raise ValueError('all input must be Dataset or DataArray objects')


def broadcast(*args, exclude=None):
"""Explicitly broadcast any number of DataArray or Dataset objects against
one another.
@@ -463,55 +515,16 @@ def broadcast(*args, exclude=None):
        a (x, y) int64 1 1 2 2 3 3
        b (x, y) int64 5 6 5 6 5 6
    """
    from .dataarray import DataArray
    from .dataset import Dataset

    if exclude is None:
        exclude = set()
    args = align(*args, join='outer', copy=False, exclude=exclude)

    common_coords = OrderedDict()
    dims_map = OrderedDict()
    for arg in args:
        for dim in arg.dims:
            if dim not in common_coords and dim not in exclude:
                dims_map[dim] = arg.sizes[dim]
                if dim in arg.coords:
                    common_coords[dim] = arg.coords[dim].variable

    def _set_dims(var):
        # Add excluded dims to a copy of dims_map
        var_dims_map = dims_map.copy()
        for dim in exclude:
            with suppress(ValueError):
                # ignore dim not in var.dims
                var_dims_map[dim] = var.shape[var.dims.index(dim)]

        return var.set_dims(var_dims_map)

    def _broadcast_array(array):
        data = _set_dims(array.variable)
        coords = OrderedDict(array.coords)
        coords.update(common_coords)
        return DataArray(data, coords, data.dims, name=array.name,
                         attrs=array.attrs)

    def _broadcast_dataset(ds):
        data_vars = OrderedDict(
            (k, _set_dims(ds.variables[k]))
            for k in ds.data_vars)
        coords = OrderedDict(ds.coords)
        coords.update(common_coords)
        return Dataset(data_vars, coords, ds.attrs)

    dims_map, common_coords = _get_broadcast_dims_map_common_coords(
        args, exclude)
    result = []
    for arg in args:
        if isinstance(arg, DataArray):
            result.append(_broadcast_array(arg))
        elif isinstance(arg, Dataset):
            result.append(_broadcast_dataset(arg))
        else:
            raise ValueError('all input must be Dataset or DataArray objects')
        result.append(_broadcast_helper(arg, exclude, dims_map, common_coords))

    return tuple(result)

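To make the refactor above concrete, a sketch of how the two extracted helpers compose; this mirrors the pattern now shared by broadcast and the new broadcast_like methods. The helpers are private, so this is illustrative only, and the example data is made up.

import numpy as np
import xarray as xr
from xarray.core.alignment import (
    align, _broadcast_helper, _get_broadcast_dims_map_common_coords)

a = xr.DataArray(np.arange(3), dims='x')
b = xr.DataArray(np.arange(2), dims='y')

exclude = set()
# 1. Align the inputs along any shared, indexed dimensions.
args = align(b, a, join='outer', copy=False, exclude=exclude)
# 2. Collect the union of dimension sizes plus the shared coordinates.
dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude)
# 3. Expand a single object against that map -- effectively a.broadcast_like(b).
print(_broadcast_helper(a, exclude, dims_map, common_coords).dims)  # ('y', 'x')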
27 changes: 26 additions & 1 deletion xarray/core/dataarray.py
@@ -15,7 +15,9 @@
    utils)
from .accessor_dt import DatetimeAccessor
from .accessor_str import StringAccessor
from .alignment import align, reindex_like_indexers
from .alignment import (align, _broadcast_helper,
                        _get_broadcast_dims_map_common_coords,
                        reindex_like_indexers)
from .common import AbstractArray, DataWithCoords
from .coordinates import (
    DataArrayCoordinates, LevelCoordinatesSource, assert_coordinate_consistent,
@@ -993,6 +995,29 @@ def sel_points(self, dim='points', method=None, tolerance=None,
            dim=dim, method=method, tolerance=tolerance, **indexers)
        return self._from_temp_dataset(ds)

    def broadcast_like(self,
                       other: Union['DataArray', Dataset],
                       exclude=None) -> 'DataArray':
        """Broadcast this DataArray against another Dataset or DataArray.

        This is equivalent to xr.broadcast(other, self)[1].

        Parameters
        ----------
        other : Dataset or DataArray
            Object against which to broadcast this array.
        exclude : sequence of str, optional
            Dimensions that must not be broadcast in the result.
        """

        if exclude is None:
            exclude = set()
        args = align(other, self, join='outer', copy=False, exclude=exclude)

        dims_map, common_coords = _get_broadcast_dims_map_common_coords(
            args, exclude)

        return _broadcast_helper(self, exclude, dims_map, common_coords)

    def reindex_like(self, other: Union['DataArray', Dataset],
                     method: Optional[str] = None, tolerance=None,
                     copy: bool = True, fill_value=dtypes.NA) -> 'DataArray':
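A usage sketch of the new DataArray method (array values are arbitrary). Because the docstring defines it as xr.broadcast(other, self)[1], the dimensions of other come first in the result, which is why the tests further down compare against a transposed expectation.

import numpy as np
import xarray as xr

arr1 = xr.DataArray(np.random.randn(5), [('x', range(5))])
arr2 = xr.DataArray(np.random.randn(6), [('y', range(6))])

result = arr1.broadcast_like(arr2)
print(result.dims, result.shape)  # ('y', 'x') (6, 5)

# The same result spelled through the top-level function (should hold here).
assert result.identical(xr.broadcast(arr2, arr1)[1])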
27 changes: 26 additions & 1 deletion xarray/core/dataset.py
@@ -17,7 +17,8 @@
from ..coding.cftimeindex import _parse_array_of_cftime_strings
from . import (alignment, dtypes, duck_array_ops, formatting, groupby,
               indexing, ops, pdcompat, resample, rolling, utils)
from .alignment import align
from .alignment import (align, _broadcast_helper,
                        _get_broadcast_dims_map_common_coords)
from .common import (ALL_DIMS, DataWithCoords, ImplementsDatasetReduce,
                     _contains_datetime_like_objects)
from .coordinates import (DatasetCoordinates, LevelCoordinatesSource,
@@ -2016,6 +2017,30 @@ def sel_points(self, dim='points', method=None, tolerance=None,
        )
        return self.isel_points(dim=dim, **pos_indexers)

    def broadcast_like(self,
                       other: Union['Dataset', 'DataArray'],
                       exclude=None) -> 'Dataset':
        """Broadcast this Dataset against another Dataset or DataArray.

        This is equivalent to xr.broadcast(other, self)[1].

        Parameters
        ----------
        other : Dataset or DataArray
            Object against which to broadcast this Dataset.
        exclude : sequence of str, optional
            Dimensions that must not be broadcast in the result.
        """

        if exclude is None:
            exclude = set()
        args = align(other, self, join='outer', copy=False, exclude=exclude)

        dims_map, common_coords = _get_broadcast_dims_map_common_coords(
            args, exclude)

        return _broadcast_helper(self, exclude, dims_map, common_coords)

    def reindex_like(self, other, method=None, tolerance=None, copy=True,
                     fill_value=dtypes.NA):
        """Conform this object onto the indexes of another object, filling in
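And a sketch of the Dataset flavor, including the exclude parameter; the variable and dimension names here are invented for illustration. As the helper code above suggests, every data variable is expanded to the union of dimensions, while excluded dimensions are left untouched.

import numpy as np
import xarray as xr

ds = xr.Dataset({'a': ('x', np.zeros(3))}, coords={'x': [10, 20, 30]})
other = xr.DataArray(np.zeros((3, 2)),
                     coords={'x': [10, 20, 30]}, dims=('x', 'y'))

# Each data variable gains the 'y' dimension from `other`.
print(ds.broadcast_like(other)['a'].dims)                 # ('x', 'y')

# Dimensions listed in `exclude` are not broadcast.
print(ds.broadcast_like(other, exclude=['y'])['a'].dims)  # ('x',)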
15 changes: 15 additions & 0 deletions xarray/tests/test_dataarray.py
@@ -1255,6 +1255,21 @@ def test_coords_non_string(self):
        expected = DataArray(2, coords={1: 2}, name=1)
        assert_identical(actual, expected)

    def test_broadcast_like(self):
        original1 = DataArray(np.random.randn(5),
                              [('x', range(5))])

        original2 = DataArray(np.random.randn(6),
                              [('y', range(6))])

        expected1, expected2 = broadcast(original1, original2)

        assert_identical(original1.broadcast_like(original2),
                         expected1.transpose('y', 'x'))

        assert_identical(original2.broadcast_like(original1),
                         expected2)

    def test_reindex_like(self):
        foo = DataArray(np.random.randn(5, 6),
                        [('x', range(5)), ('y', range(6))])
15 changes: 15 additions & 0 deletions xarray/tests/test_dataset.py
@@ -1560,6 +1560,21 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,
        assert_identical(mdata.sel(x={'one': 'a', 'two': 1}),
                         mdata.sel(one='a', two=1))

    def test_broadcast_like(self):
        original1 = DataArray(np.random.randn(5),
                              [('x', range(5))], name='a').to_dataset()

        original2 = DataArray(np.random.randn(6),
                              [('y', range(6))], name='b')

        expected1, expected2 = broadcast(original1, original2)

        assert_identical(original1.broadcast_like(original2),
                         expected1.transpose('y', 'x'))

        assert_identical(original2.broadcast_like(original1),
                         expected2)

    def test_reindex_like(self):
        data = create_test_data()
        data['letters'] = ('dim3', 10 * ['a'])
