Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

einsum for xarray #1968

Merged
merged 14 commits into from
Mar 12, 2018
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Top-level functions
full_like
zeros_like
ones_like
dot

Dataset
=======
Expand Down
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ Documentation
Enhancements
~~~~~~~~~~~~

- Addition of :py:func:`~xarray.dot`, equivalent to ``np.einsum``.
Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option,
which specifies the dimensions to sum over.
(:issue:`1951`)
- Support lazy vectorized-indexing. After this change, flexible indexing such
as orthogonal/vectorized indexing, becomes possible for all the backend
arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- Improve :py:func:`~xarray.DataArray.rolling` logic.
:py:func:`~xarray.DataArrayRolling` object now supports
:py:func:`~xarray.DataArrayRolling.construct` method that returns a view
Expand Down
2 changes: 1 addition & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .core.alignment import align, broadcast, broadcast_arrays
from .core.common import full_like, zeros_like, ones_like
from .core.combine import concat, auto_combine
from .core.computation import apply_ufunc, where
from .core.computation import apply_ufunc, where, dot
from .core.extensions import (register_dataarray_accessor,
register_dataset_accessor)
from .core.variable import as_variable, Variable, IndexVariable, Coordinate
Expand Down
104 changes: 102 additions & 2 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import functools
import itertools
import operator
from collections import Counter

import numpy as np

from . import duck_array_ops, utils
from . import duck_array_ops, utils, dtypes
from .alignment import deep_align
from .merge import expand_and_merge_variables
from .pycompat import OrderedDict, dask_array_type
from .pycompat import OrderedDict, dask_array_type, basestring
from .utils import is_dict_like

_DEFAULT_FROZEN_SET = frozenset()
Expand Down Expand Up @@ -926,6 +927,105 @@ def earth_mover_distance(first_samples,
return apply_array_ufunc(func, *args, dask=dask)


def dot(*arrays, **kwargs):
""" dot(*arrays, dims=None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dot(*arrays, *, dims=None) is the way to write this with Python 3's keyword only arguments.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we would keep this as dot(*arrays, **kwargs) as we did not yet drop python 2 support?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was confused. def dot(*arrays, *, dims=None) is not valid syntax in Python 3, either. (There can only be one single *)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PEP3102 says we python 3 supports the form def dot(*arrays, dim=None).


Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.

Parameters
----------
arrays: DataArray objects
Arrays to compute.
dims: str or tuple of strings, optional
Which dimensions to sum over.
If not speciified, then all the common dimensions are summed over.

Returns
-------
dot: DataArray

Examples
--------

>>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b'])
>>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5),
>>> dims=['a', 'b', 'c'])
>>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd'])
>>>
>>> xr.dot(da_a, da_b, dims=['a', 'b']).dims
('c', )
>>> xr.dot(da_a, da_b, dims=['a']).dims
('b', 'c')
>>> xr.dot(da_a, da_b, da_c, dims=['b', 'c']).dims
('a', 'd')
"""
from .dataarray import DataArray

dims = kwargs.pop('dims', None)
if len(kwargs) > 0:
raise TypeError('Invalid keyward arguments {} are given'.format(
list(kwargs.keys())))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you write xr.dot()? I suppose we still need to raise an error for 0 arguments.

if any(not isinstance(arr, DataArray) for arr in arrays):
raise TypeError('Only xr.DataArray and xr.Variable are supported.')

if isinstance(dims, basestring):
dims = [dims]

common_dims = set(arrays[0].dims)
all_dims = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it work to make all_dims a set instead of a list? I think that would be slightly more efficient.

Copy link
Member Author

@fujiisoup fujiisoup Mar 8, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to keep the occurrence order in all_dims, so that to move input_core_dims positions back to the original position.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, sounds good.

for arr in arrays[1:]:
common_dims = common_dims.intersection(set(arr.dims))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a slightly different choice of default dimensions than np.einsum:

  • np.einsum sums over any dimensions that are defined in two over more inputs.
  • This sums only over dimensions that are defined on all inputs.

Should we switch this behavior to match einsum?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be slightly more efficient to construct common_dims with a single call to intersection?

e.g.,
common_dims = set.intersection(*[set(arr.dims) for arr in arrays])

for arr in arrays:
all_dims += [d for d in arr.dims if d not in all_dims]

einsum_axes = 'abcdefghijklmnopqrstuvwxyz'
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dims is None:
# find dimensions that occur more than one times
dim_counts = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
dims = [d for d, c in dim_counts.items() if c > 1]

broadcast_dims = [d for d in common_dims if d not in dims]
input_core_dims = []
output_core_dims = [[]]
for arr in arrays:
input_core_dims.append([d for d in arr.dims if d not in
broadcast_dims])
output_core_dims[0] += [d for d in arr.dims if d not in
output_core_dims[0] + dims + broadcast_dims]

subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds
in input_core_dims]
subscripts = ','.join(subscripts_list)
subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]])

# dtype estimation is necessary for dask='parallelized'
out_dtype = dtypes.result_type(*arrays)

# we use tensordot if possible, because it is more efficient for dask
if len(broadcast_dims) == 0 and len(arrays) == 2:
axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims]
for arr in arrays]
return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed',
input_core_dims=input_core_dims,
output_core_dims=output_core_dims,
kwargs={'axes': axes})
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I added a path for tensordot, which dask can compute more efficiently.


# subscripts should be passed as arg, not as a kwargs. We need
# to construct a partial function for parallelized computation.
func = functools.partial(np.einsum, subscripts)
result = apply_ufunc(func, *arrays,
input_core_dims=input_core_dims,
output_core_dims=output_core_dims,
dask='parallelized', output_dtypes=[out_dtype])
return result.transpose(*[d for d in all_dims if d in result.dims])


def where(cond, x, y):
"""Return elements from `x` or `y` depending on `cond`.

Expand Down
26 changes: 7 additions & 19 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

from . import duck_array_ops, groupby, indexing, ops, resample, rolling, utils
from . import computation, groupby, indexing, ops, resample, rolling, utils
from ..plot.plot import _PlotMethods
from .accessors import DatetimeAccessor
from .alignment import align, reindex_like_indexers
Expand Down Expand Up @@ -1926,7 +1926,7 @@ def real(self):
def imag(self):
return self._replace(self.variable.imag)

def dot(self, other):
def dot(self, other, dims=None):
"""Perform dot product of two DataArrays along their shared dims.

Equivalent to taking taking tensordot over all shared dims.
Expand All @@ -1935,6 +1935,9 @@ def dot(self, other):
----------
other : DataArray
The other array with which the dot product is performed.
dims: list of strings, optional
Along which dimensions to be summed over. Default all the common
dimensions are summed over.

Returns
-------
Expand All @@ -1943,6 +1946,7 @@ def dot(self, other):

See also
--------
dot
numpy.tensordot

Examples
Expand All @@ -1968,23 +1972,7 @@ def dot(self, other):
if not isinstance(other, DataArray):
raise TypeError('dot only operates on DataArrays.')

# sum over the common dims
dims = set(self.dims) & set(other.dims)
if len(dims) == 0:
raise ValueError('DataArrays have no shared dimensions over which '
'to perform dot.')

self, other = align(self, other, join='inner', copy=False)

axes = (self.get_axis_num(dims), other.get_axis_num(dims))
new_data = duck_array_ops.tensordot(self.data, other.data, axes=axes)

new_coords = self.coords.merge(other.coords)
new_coords = new_coords.drop([d for d in dims if d in new_coords])
new_dims = ([d for d in self.dims if d not in dims] +
[d for d in other.dims if d not in dims])

return type(self)(new_data, new_coords.variables, new_dims)
return computation.dot(self, other, dims=dims)

def sortby(self, variables, ascending=True):
"""
Expand Down
83 changes: 82 additions & 1 deletion xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
join_dict_keys, ordered_set_intersection, ordered_set_union, result_name,
unified_dim_sizes)

from . import raises_regex, requires_dask
from . import raises_regex, requires_dask, has_dask


def assert_identical(a, b):
Expand Down Expand Up @@ -744,6 +744,87 @@ def test_vectorize_dask():
assert_identical(expected, actual)


@pytest.mark.parametrize('dask', [True, False])
def test_dot(dask):
if not has_dask:
pytest.skip('test for dask.')

a = np.arange(30 * 4).reshape(30, 4)
b = np.arange(30 * 4 * 5).reshape(30, 4, 5)
c = np.arange(5 * 60).reshape(5, 60)
da_a = xr.DataArray(a, dims=['a', 'b'],
coords={'a': np.linspace(0, 1, 30)})
da_b = xr.DataArray(b, dims=['a', 'b', 'c'],
coords={'a': np.linspace(0, 1, 30)})
da_c = xr.DataArray(c, dims=['c', 'e'])
if dask:
da_a = da_a.chunk({'a': 3})
da_b = da_b.chunk({'a': 3})
da_c = da_c.chunk({'c': 3})

actual = xr.dot(da_a, da_b, dims=['a', 'b'])
assert actual.dims == ('c', )
assert (actual.data == np.einsum('ij,ijk->k', a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))

actual = xr.dot(da_a, da_b)
assert actual.dims == ('c', )
assert (actual.data == np.einsum('ij,ijk->k', a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))

# for only a single array is passed without dims argument, just return
# as is
actual = xr.dot(da_a)
assert da_a.identical(actual)

if dask:
da_a = da_a.chunk({'a': 3})
da_b = da_b.chunk({'a': 3})
actual = xr.dot(da_a, da_b, dims=['b'])
assert actual.dims == ('a', 'c')
assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))

pytest.skip('dot for dask array requires rechunking for core '
'dimensions.')

# following requires rechunking
actual = xr.dot(da_a, da_b, dims=['b'])
assert actual.dims == ('a', 'c')
assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()

actual = xr.dot(da_a, da_b, dims='b')
assert actual.dims == ('a', 'c')
assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()

actual = xr.dot(da_a, da_b, dims='a')
assert actual.dims == ('b', 'c')
assert (actual.data == np.einsum('ij,ijk->jk', a, b)).all()

actual = xr.dot(da_a, da_b, dims='c')
assert actual.dims == ('a', 'b')
assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=['a', 'b'])
assert actual.dims == ('c', 'e')
assert (actual.data == np.einsum('ij,ijk,kl->kl ', a, b, c)).all()

# default dims
actual = xr.dot(da_a, da_b, da_c)
assert actual.dims == ('e', )
assert (actual.data == np.einsum('ij,ijk,kl->l ', a, b, c)).all()

# 1 array summation
actual = xr.dot(da_a, dims='a')
assert actual.dims == ('b', )
assert (actual.data == np.einsum('ij->j ', a)).all()

with pytest.raises(TypeError):
actual = xr.dot(da_a, dims='a', invalid=None)
with pytest.raises(TypeError):
actual = xr.dot(da_a.to_dataset(name='da'), dims='a')


def test_where():
cond = xr.DataArray([True, False], dims='x')
actual = xr.where(cond, 1, 0)
Expand Down
2 changes: 0 additions & 2 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3200,8 +3200,6 @@ def test_dot(self):
da.dot(dm.to_dataset(name='dm'))
with pytest.raises(TypeError):
da.dot(dm.values)
with raises_regex(ValueError, 'no shared dimensions'):
da.dot(DataArray(1))

def test_binary_op_join_setting(self):
dim = 'x'
Expand Down