From 220ebccb8120288f44bfbf79e9beb2867a4e826d Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Tue, 6 Mar 2018 23:11:06 +0900 Subject: [PATCH 01/11] einsum for xarray --- doc/whats-new.rst | 6 +++ xarray/__init__.py | 2 +- xarray/core/computation.py | 82 +++++++++++++++++++++++++++++++- xarray/core/dataarray.py | 31 +++++------- xarray/tests/test_computation.py | 43 +++++++++++++++++ 5 files changed, 143 insertions(+), 21 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ab667ceba3f..f229e87c4ef 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,12 @@ Documentation Enhancements ~~~~~~~~~~~~ +- Addition of :py:func:`~xarray.dot`, which is equivalent to ``np.einsum``. + Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option, + which specifies along which dimensions to be summed over. + (:issue:`1951`) + By `Keisuke Fujii `_. + - Improve :py:func:`~xarray.DataArray.rolling` logic. :py:func:`~xarray.DataArrayRolling` object now supports :py:func:`~xarray.DataArrayRolling.construct` method that returns a view diff --git a/xarray/__init__.py b/xarray/__init__.py index 3e80acd1572..1a2bf3fe283 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -6,7 +6,7 @@ from .core.alignment import align, broadcast, broadcast_arrays from .core.common import full_like, zeros_like, ones_like from .core.combine import concat, auto_combine -from .core.computation import apply_ufunc, where +from .core.computation import apply_ufunc, where, dot from .core.extensions import (register_dataarray_accessor, register_dataset_accessor) from .core.variable import as_variable, Variable, IndexVariable, Coordinate diff --git a/xarray/core/computation.py b/xarray/core/computation.py index b7590ab6b4b..57c85a91de9 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -12,7 +12,7 @@ from . import duck_array_ops, utils from .alignment import deep_align from .merge import expand_and_merge_variables -from .pycompat import OrderedDict, dask_array_type +from .pycompat import OrderedDict, dask_array_type, basestring from .utils import is_dict_like _DEFAULT_FROZEN_SET = frozenset() @@ -926,6 +926,86 @@ def earth_mover_distance(first_samples, return apply_array_ufunc(func, *args, dask=dask) +def dot(*args, **kwargs): + """ dot(*arrays, dims=None) + + einsum for xarray object. + + Parameters + ---------- + arrays: arrays to compute + dims: tuple of strings, optional + Along which dimensions to be summed over. + If None is provided, then all the common dimensions are summed over. + + Returns + ------- + dot: same type to input. + + Examples + -------- + + >>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b']) + >>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5), + dims=['a', 'b', 'c']) + >>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd']) + + >>> dot(da_a, da_b, dims=['a', 'b']).dims + ('c', ) + >>> dot(da_a, da_b, dims=['a']).dims + ('b', 'c') + >>> dot(da_a, da_b, da_c, dims=['b', 'c']).dims + ('a', 'd') + """ + dims = kwargs.pop('dims', None) + arrays = args + if dims is None and isinstance(args[-1], (list, tuple, basestring)): + dims = args[-1] + arrays = args[:-1] + + if len(arrays) < 2: + raise TypeError('More than two arrays must be provided') + + if any(not hasattr(arr, 'dims') for arr in arrays): + raise TypeError('Only xr.DataArray and xr.Variable are supported.') + + if isinstance(dims, basestring): + dims = [dims] + + common_dims = set(arrays[0].dims) + for arr in arrays[1:]: + common_dims = common_dims.intersection(set(arr.dims)) + + if dims is None: + dims = list(common_dims) + + broadcast_dims = [d for d in common_dims if d not in dims] + input_core_dims = [] + output_core_dims = [[]] + all_dims = [] + + for arr in arrays: + input_core_dims.append([d for d in arr.dims if d not in + broadcast_dims]) + output_core_dims[0] += [d for d in arr.dims if d not in + output_core_dims[0] + dims + broadcast_dims] + all_dims += [d for d in arr.dims if d not in all_dims] + + einsum_axes = 'abcdefghijklmnopqrstuvwxyz' + dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} + + subscripts = '' + for ds in input_core_dims: + subscripts += '...' + ''.join([dim_map[d] for d in ds]) + ',' + subscripts = subscripts[:-1] # remove last comma + subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) + + result = apply_ufunc(np.einsum, subscripts, *arrays, + input_core_dims=[[]] + input_core_dims, + output_core_dims=output_core_dims, dask='allowed') + return result.transpose(*[d for d in all_dims if d in result.dims]) + + def where(cond, x, y): """Return elements from `x` or `y` depending on `cond`. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8c0360df8a9..f6586ca33c5 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from . import duck_array_ops, groupby, indexing, ops, resample, rolling, utils +from . import computation, groupby, indexing, ops, resample, rolling, utils from ..plot.plot import _PlotMethods from .accessors import DatetimeAccessor from .alignment import align, reindex_like_indexers @@ -1926,7 +1926,7 @@ def real(self): def imag(self): return self._replace(self.variable.imag) - def dot(self, other): + def dot(self, other, dims=None): """Perform dot product of two DataArrays along their shared dims. Equivalent to taking taking tensordot over all shared dims. @@ -1935,6 +1935,9 @@ def dot(self, other): ---------- other : DataArray The other array with which the dot product is performed. + dims: list of strings, optional + Along which dimensions to be summed over. Default all the common + dimensions are summed over. Returns ------- @@ -1943,6 +1946,7 @@ def dot(self, other): See also -------- + dot numpy.tensordot Examples @@ -1968,23 +1972,12 @@ def dot(self, other): if not isinstance(other, DataArray): raise TypeError('dot only operates on DataArrays.') - # sum over the common dims - dims = set(self.dims) & set(other.dims) - if len(dims) == 0: - raise ValueError('DataArrays have no shared dimensions over which ' - 'to perform dot.') - - self, other = align(self, other, join='inner', copy=False) - - axes = (self.get_axis_num(dims), other.get_axis_num(dims)) - new_data = duck_array_ops.tensordot(self.data, other.data, axes=axes) - - new_coords = self.coords.merge(other.coords) - new_coords = new_coords.drop([d for d in dims if d in new_coords]) - new_dims = ([d for d in self.dims if d not in dims] + - [d for d in other.dims if d not in dims]) - - return type(self)(new_data, new_coords.variables, new_dims) + # backward compat: if there is no shared dimension, we rais an Errror + shared_dims = set(self.dims) & set(other.dims) + if len(shared_dims) == 0: + raise ValueError('no shared dimensions. Given {} and {}.'.format( + self.dims, other.dims)) + return computation.dot(self, other, dims=dims) def sortby(self, variables, ascending=True): """ diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index ebd51d04857..97bcb422882 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -744,6 +744,49 @@ def test_vectorize_dask(): assert_identical(expected, actual) +@pytest.mark.parametrize('dask', [False, True]) +def test_dot(dask): + a = np.arange(3 * 4).reshape(3, 4) + b = np.arange(3 * 4 * 5).reshape(3, 4, 5) + c = np.arange(5 * 6).reshape(5, 6) + da_a = xr.DataArray(a, dims=['a', 'b']) + da_b = xr.DataArray(b, dims=['a', 'b', 'c']) + da_c = xr.DataArray(c, dims=['c', 'e']) + + if dask: + da_a = da_a.chunk({'b': 2}) + da_b = da_b.chunk({'b': 2}) + da_c = da_c.chunk({'e': 3}) + + actual = xr.dot(da_a, da_b, ['a', 'b']) + assert actual.dims == ('c', ) + assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + + actual = xr.dot(da_a, da_b) + assert actual.dims == ('c', ) + assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + + actual = xr.dot(da_a, da_b, ['b']) + assert actual.dims == ('a', 'c') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + + actual = xr.dot(da_a, da_b, 'b') + assert actual.dims == ('a', 'c') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + + actual = xr.dot(da_a, da_b, 'a') + assert actual.dims == ('b', 'c') + assert (actual.data == np.einsum('ij,ijk->jk', a, b)).all() + + actual = xr.dot(da_a, da_b, 'c') + assert actual.dims == ('a', 'b') + assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all() + + actual = xr.dot(da_a, da_b, da_c, dims=['a', 'b']) + assert actual.dims == ('c', 'e') + assert (actual.data == np.einsum('ij,ijk,kl->kl ', a, b, c)).all() + + def test_where(): cond = xr.DataArray([True, False], dims='x') actual = xr.where(cond, 1, 0) From 4239ac6edadbabf134b62ebb5aa4d7f5efa70733 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Tue, 6 Mar 2018 23:16:59 +0900 Subject: [PATCH 02/11] whats new --- doc/whats-new.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f229e87c4ef..a1a15193158 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,9 +38,9 @@ Documentation Enhancements ~~~~~~~~~~~~ -- Addition of :py:func:`~xarray.dot`, which is equivalent to ``np.einsum``. +- Addition of :py:func:`~xarray.dot`, equivalent to ``np.einsum``. Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option, - which specifies along which dimensions to be summed over. + which specifies the dimensions to be summed over. (:issue:`1951`) By `Keisuke Fujii `_. From 0f472a261bf08595f23f771d1f3f846e0da37f63 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Wed, 7 Mar 2018 15:33:26 +0900 Subject: [PATCH 03/11] Support dask for xr.dot. --- doc/api.rst | 1 + xarray/core/computation.py | 55 +++++++++++++++++++++----------- xarray/core/dataarray.py | 5 --- xarray/tests/test_computation.py | 44 +++++++++++++++++-------- xarray/tests/test_dataarray.py | 2 -- 5 files changed, 67 insertions(+), 40 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index ae4803e5e62..1814b874b3e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -24,6 +24,7 @@ Top-level functions full_like zeros_like ones_like + dot Dataset ======= diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 57c85a91de9..c8c72fbab95 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -9,7 +9,7 @@ import numpy as np -from . import duck_array_ops, utils +from . import duck_array_ops, utils, dtypes from .alignment import deep_align from .merge import expand_and_merge_variables from .pycompat import OrderedDict, dask_array_type, basestring @@ -926,17 +926,19 @@ def earth_mover_distance(first_samples, return apply_array_ufunc(func, *args, dask=dask) -def dot(*args, **kwargs): - """ dot(*arrays, dims=None) +def dot(*arrays, **kwargs): + """ dot(*arrays, *, dims=None) - einsum for xarray object. + einsum for xarray object, but providing simpler interface based on + the array dimensions. Parameters ---------- - arrays: arrays to compute + arrays: multiple DataArrays + arrays to compute. dims: tuple of strings, optional Along which dimensions to be summed over. - If None is provided, then all the common dimensions are summed over. + If not speciified, then all the common dimensions are summed over. Returns ------- @@ -957,16 +959,15 @@ def dot(*args, **kwargs): >>> dot(da_a, da_b, da_c, dims=['b', 'c']).dims ('a', 'd') """ + from .dataarray import DataArray + from .variable import Variable + dims = kwargs.pop('dims', None) - arrays = args - if dims is None and isinstance(args[-1], (list, tuple, basestring)): - dims = args[-1] - arrays = args[:-1] if len(arrays) < 2: - raise TypeError('More than two arrays must be provided') + raise TypeError('More than one arrays must be provided') - if any(not hasattr(arr, 'dims') for arr in arrays): + if any(not isinstance(arr, (DataArray, Variable)) for arr in arrays): raise TypeError('Only xr.DataArray and xr.Variable are supported.') if isinstance(dims, basestring): @@ -980,6 +981,7 @@ def dot(*args, **kwargs): dims = list(common_dims) broadcast_dims = [d for d in common_dims if d not in dims] + input_core_dims = [] output_core_dims = [[]] all_dims = [] @@ -994,15 +996,30 @@ def dot(*args, **kwargs): einsum_axes = 'abcdefghijklmnopqrstuvwxyz' dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} - subscripts = '' - for ds in input_core_dims: - subscripts += '...' + ''.join([dim_map[d] for d in ds]) + ',' - subscripts = subscripts[:-1] # remove last comma + subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds + in input_core_dims] + subscripts = ','.join(subscripts_list) subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) - result = apply_ufunc(np.einsum, subscripts, *arrays, - input_core_dims=[[]] + input_core_dims, - output_core_dims=output_core_dims, dask='allowed') + # dtype estimation is necessary for dask='parallelized' + out_dtype = dtypes.result_type(*arrays) + + # we use tensordot if available, because it is more efficient for dask + if len(broadcast_dims) == 0 and len(arrays) == 2: + axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims] + for arr in arrays] + return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed', + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + kwargs={'axes': axes}) + + # subscripts should be passed as arg, instead of kwargs. We need + # to pass a partial function especially for parallelized computation. + func = functools.partial(np.einsum, subscripts) + result = apply_ufunc(func, *arrays, + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + dask='parallelized', output_dtypes=[out_dtype]) return result.transpose(*[d for d in all_dims if d in result.dims]) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f6586ca33c5..3c022752174 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1972,11 +1972,6 @@ def dot(self, other, dims=None): if not isinstance(other, DataArray): raise TypeError('dot only operates on DataArrays.') - # backward compat: if there is no shared dimension, we rais an Errror - shared_dims = set(self.dims) & set(other.dims) - if len(shared_dims) == 0: - raise ValueError('no shared dimensions. Given {} and {}.'.format( - self.dims, other.dims)) return computation.dot(self, other, dims=dims) def sortby(self, variables, ascending=True): diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 97bcb422882..12d8345478f 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -13,8 +13,9 @@ _UFuncSignature, apply_ufunc, broadcast_compat_data, collect_dict_values, join_dict_keys, ordered_set_intersection, ordered_set_union, result_name, unified_dim_sizes) +from xarray.core.pycompat import dask_array_type -from . import raises_regex, requires_dask +from . import raises_regex, requires_dask, has_dask def assert_identical(a, b): @@ -744,41 +745,56 @@ def test_vectorize_dask(): assert_identical(expected, actual) -@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('dask', [True, False]) def test_dot(dask): - a = np.arange(3 * 4).reshape(3, 4) - b = np.arange(3 * 4 * 5).reshape(3, 4, 5) - c = np.arange(5 * 6).reshape(5, 6) + pytest.mark.skipif(not has_dask, reason='test for dask.') + + a = np.arange(30 * 4).reshape(30, 4) + b = np.arange(30 * 4 * 5).reshape(30, 4, 5) + c = np.arange(5 * 60).reshape(5, 60) da_a = xr.DataArray(a, dims=['a', 'b']) da_b = xr.DataArray(b, dims=['a', 'b', 'c']) da_c = xr.DataArray(c, dims=['c', 'e']) - if dask: - da_a = da_a.chunk({'b': 2}) - da_b = da_b.chunk({'b': 2}) - da_c = da_c.chunk({'e': 3}) + da_a = da_a.chunk({'a': 3}) + da_b = da_b.chunk({'a': 3}) + da_c = da_c.chunk({'c': 3}) - actual = xr.dot(da_a, da_b, ['a', 'b']) + actual = xr.dot(da_a, da_b, dims=['a', 'b']) assert actual.dims == ('c', ) assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + assert isinstance(actual.variable.data, type(da_a.variable.data)) actual = xr.dot(da_a, da_b) assert actual.dims == ('c', ) assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + assert isinstance(actual.variable.data, type(da_a.variable.data)) - actual = xr.dot(da_a, da_b, ['b']) + if dask: + da_a = da_a.chunk({'a': 3}) + da_b = da_b.chunk({'a': 3}) + actual = xr.dot(da_a, da_b, dims=['b']) + assert actual.dims == ('a', 'c') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + assert isinstance(actual.variable.data, type(da_a.variable.data)) + + pytest.skip('dot for dask array requires rechunking for core ' + 'dimensions.') + + # following requires rechunking + actual = xr.dot(da_a, da_b, dims=['b']) assert actual.dims == ('a', 'c') assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() - actual = xr.dot(da_a, da_b, 'b') + actual = xr.dot(da_a, da_b, dims='b') assert actual.dims == ('a', 'c') assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() - actual = xr.dot(da_a, da_b, 'a') + actual = xr.dot(da_a, da_b, dims='a') assert actual.dims == ('b', 'c') assert (actual.data == np.einsum('ij,ijk->jk', a, b)).all() - actual = xr.dot(da_a, da_b, 'c') + actual = xr.dot(da_a, da_b, dims='c') assert actual.dims == ('a', 'b') assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 18fc27c96ab..c2ed9c288ac 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3200,8 +3200,6 @@ def test_dot(self): da.dot(dm.to_dataset(name='dm')) with pytest.raises(TypeError): da.dot(dm.values) - with raises_regex(ValueError, 'no shared dimensions'): - da.dot(DataArray(1)) def test_binary_op_join_setting(self): dim = 'x' From 1c732a43525426610bf7b7370330d08aedf77bbd Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Wed, 7 Mar 2018 16:26:29 +0900 Subject: [PATCH 04/11] flake8. Add some error messages. --- doc/whats-new.rst | 3 +- xarray/core/computation.py | 58 ++++++++++++++++++-------------- xarray/tests/test_computation.py | 19 +++++++++-- 3 files changed, 50 insertions(+), 30 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6b183758b97..acd014dcc2d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,13 +40,12 @@ Enhancements - Addition of :py:func:`~xarray.dot`, equivalent to ``np.einsum``. Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option, - which specifies the dimensions to be summed over. + which specifies the dimensions to sum over. (:issue:`1951`) - Support lazy vectorized-indexing. After this change, flexible indexing such as orthogonal/vectorized indexing, becomes possible for all the backend arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`) By `Keisuke Fujii `_. - - Improve :py:func:`~xarray.DataArray.rolling` logic. :py:func:`~xarray.DataArrayRolling` object now supports :py:func:`~xarray.DataArrayRolling.construct` method that returns a view diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c8c72fbab95..1c9ef2dab34 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -927,45 +927,48 @@ def earth_mover_distance(first_samples, def dot(*arrays, **kwargs): - """ dot(*arrays, *, dims=None) + """ dot(*arrays, dims=None) - einsum for xarray object, but providing simpler interface based on - the array dimensions. + Generalized dot product for xarray objects. Like np.einsum, but + provides a simpler interface based on array dimensions. Parameters ---------- - arrays: multiple DataArrays - arrays to compute. - dims: tuple of strings, optional - Along which dimensions to be summed over. + arrays: DataArray objects + Arrays to compute. + dims: str or tuple of strings, optional + Which dimensions to sum over. If not speciified, then all the common dimensions are summed over. Returns ------- - dot: same type to input. + dot: DataArray Examples -------- >>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b']) >>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5), - dims=['a', 'b', 'c']) + >>> dims=['a', 'b', 'c']) >>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd']) - - >>> dot(da_a, da_b, dims=['a', 'b']).dims + >>> + >>> xr.dot(da_a, da_b, dims=['a', 'b']).dims ('c', ) - >>> dot(da_a, da_b, dims=['a']).dims + >>> xr.dot(da_a, da_b, dims=['a']).dims ('b', 'c') - >>> dot(da_a, da_b, da_c, dims=['b', 'c']).dims + >>> xr.dot(da_a, da_b, da_c, dims=['b', 'c']).dims ('a', 'd') """ from .dataarray import DataArray from .variable import Variable dims = kwargs.pop('dims', None) + if len(kwargs) > 0: + raise TypeError('Invalid keyward arguments {} are given'.format( + kwargs.keys())) - if len(arrays) < 2: - raise TypeError('More than one arrays must be provided') + if len(arrays) < 2 and dims is None: + raise TypeError('dim must be provided for one array computation.') if any(not isinstance(arr, (DataArray, Variable)) for arr in arrays): raise TypeError('Only xr.DataArray and xr.Variable are supported.') @@ -974,27 +977,30 @@ def dot(*arrays, **kwargs): dims = [dims] common_dims = set(arrays[0].dims) + all_dims = [] for arr in arrays[1:]: common_dims = common_dims.intersection(set(arr.dims)) + for arr in arrays: + all_dims += [d for d in arr.dims if d not in all_dims] + + einsum_axes = 'abcdefghijklmnopqrstuvwxyz' + dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} if dims is None: - dims = list(common_dims) + # find dimensions that exist in more than two arrays + whole_dims = [] + for arr in arrays: + whole_dims += [d for d in arr.dims] + dims = [d for d in all_dims if whole_dims.count(d) > 1] broadcast_dims = [d for d in common_dims if d not in dims] - input_core_dims = [] output_core_dims = [[]] - all_dims = [] - for arr in arrays: input_core_dims.append([d for d in arr.dims if d not in broadcast_dims]) output_core_dims[0] += [d for d in arr.dims if d not in output_core_dims[0] + dims + broadcast_dims] - all_dims += [d for d in arr.dims if d not in all_dims] - - einsum_axes = 'abcdefghijklmnopqrstuvwxyz' - dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds in input_core_dims] @@ -1004,7 +1010,7 @@ def dot(*arrays, **kwargs): # dtype estimation is necessary for dask='parallelized' out_dtype = dtypes.result_type(*arrays) - # we use tensordot if available, because it is more efficient for dask + # we use tensordot if possible, because it is more efficient for dask if len(broadcast_dims) == 0 and len(arrays) == 2: axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims] for arr in arrays] @@ -1013,8 +1019,8 @@ def dot(*arrays, **kwargs): output_core_dims=output_core_dims, kwargs={'axes': axes}) - # subscripts should be passed as arg, instead of kwargs. We need - # to pass a partial function especially for parallelized computation. + # subscripts should be passed as arg, not as a kwargs. We need + # to construct a partial function for parallelized computation. func = functools.partial(np.einsum, subscripts) result = apply_ufunc(func, *arrays, input_core_dims=input_core_dims, diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 12d8345478f..05667cf63ec 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -13,7 +13,6 @@ _UFuncSignature, apply_ufunc, broadcast_compat_data, collect_dict_values, join_dict_keys, ordered_set_intersection, ordered_set_union, result_name, unified_dim_sizes) -from xarray.core.pycompat import dask_array_type from . import raises_regex, requires_dask, has_dask @@ -747,7 +746,8 @@ def test_vectorize_dask(): @pytest.mark.parametrize('dask', [True, False]) def test_dot(dask): - pytest.mark.skipif(not has_dask, reason='test for dask.') + if not has_dask: + pytest.skip('test for dask.') a = np.arange(30 * 4).reshape(30, 4) b = np.arange(30 * 4 * 5).reshape(30, 4, 5) @@ -802,6 +802,21 @@ def test_dot(dask): assert actual.dims == ('c', 'e') assert (actual.data == np.einsum('ij,ijk,kl->kl ', a, b, c)).all() + # default dims + actual = xr.dot(da_a, da_b, da_c) + assert actual.dims == ('e', ) + assert (actual.data == np.einsum('ij,ijk,kl->l ', a, b, c)).all() + + # 1 array summation + actual = xr.dot(da_a, dims='a') + assert actual.dims == ('b', ) + assert (actual.data == np.einsum('ij->j ', a)).all() + + with pytest.raises(TypeError): + actual = xr.dot(da_a) + with pytest.raises(TypeError): + actual = xr.dot(da_a, dims='a', invalid=None) + def test_where(): cond = xr.DataArray([True, False], dims='x') From b8d93b0819407795b375a3befc39c1aaec8337cb Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Wed, 7 Mar 2018 18:29:52 +0900 Subject: [PATCH 05/11] fix for sticker-ci --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1c9ef2dab34..9528937ba64 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -965,7 +965,7 @@ def dot(*arrays, **kwargs): dims = kwargs.pop('dims', None) if len(kwargs) > 0: raise TypeError('Invalid keyward arguments {} are given'.format( - kwargs.keys())) + list(kwargs.keys()))) if len(arrays) < 2 and dims is None: raise TypeError('dim must be provided for one array computation.') From 3278bf3d2f4c7bb05cef63d4bfd518c96367a99e Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Thu, 8 Mar 2018 09:09:17 +0900 Subject: [PATCH 06/11] Use counter --- xarray/core/computation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 9528937ba64..9a2d37e2eba 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -6,6 +6,7 @@ import functools import itertools import operator +from collections import Counter import numpy as np @@ -988,10 +989,10 @@ def dot(*arrays, **kwargs): if dims is None: # find dimensions that exist in more than two arrays - whole_dims = [] + dim_counts = Counter() for arr in arrays: - whole_dims += [d for d in arr.dims] - dims = [d for d in all_dims if whole_dims.count(d) > 1] + dim_counts.update(arr.dims) + dims = [d for d, c in dim_counts.items() if c > 1] broadcast_dims = [d for d in common_dims if d not in dims] input_core_dims = [] From 1ec568317764abec2c407f287a9c7d33b017f84e Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 8 Mar 2018 10:50:01 +0900 Subject: [PATCH 07/11] Always allow dims=None for xr.dot. --- xarray/core/computation.py | 8 ++------ xarray/tests/test_computation.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 9a2d37e2eba..f8854fad857 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -961,17 +961,13 @@ def dot(*arrays, **kwargs): ('a', 'd') """ from .dataarray import DataArray - from .variable import Variable dims = kwargs.pop('dims', None) if len(kwargs) > 0: raise TypeError('Invalid keyward arguments {} are given'.format( list(kwargs.keys()))) - if len(arrays) < 2 and dims is None: - raise TypeError('dim must be provided for one array computation.') - - if any(not isinstance(arr, (DataArray, Variable)) for arr in arrays): + if any(not isinstance(arr, DataArray) for arr in arrays): raise TypeError('Only xr.DataArray and xr.Variable are supported.') if isinstance(dims, basestring): @@ -988,7 +984,7 @@ def dot(*arrays, **kwargs): dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} if dims is None: - # find dimensions that exist in more than two arrays + # find dimensions that occur more than one times dim_counts = Counter() for arr in arrays: dim_counts.update(arr.dims) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 05667cf63ec..abb37ff4b1c 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -752,8 +752,10 @@ def test_dot(dask): a = np.arange(30 * 4).reshape(30, 4) b = np.arange(30 * 4 * 5).reshape(30, 4, 5) c = np.arange(5 * 60).reshape(5, 60) - da_a = xr.DataArray(a, dims=['a', 'b']) - da_b = xr.DataArray(b, dims=['a', 'b', 'c']) + da_a = xr.DataArray(a, dims=['a', 'b'], + coords={'a': np.linspace(0, 1, 30)}) + da_b = xr.DataArray(b, dims=['a', 'b', 'c'], + coords={'a': np.linspace(0, 1, 30)}) da_c = xr.DataArray(c, dims=['c', 'e']) if dask: da_a = da_a.chunk({'a': 3}) @@ -770,6 +772,11 @@ def test_dot(dask): assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) + # for only a single array is passed without dims argument, just return + # as is + actual = xr.dot(da_a) + assert da_a.identical(actual) + if dask: da_a = da_a.chunk({'a': 3}) da_b = da_b.chunk({'a': 3}) @@ -812,10 +819,10 @@ def test_dot(dask): assert actual.dims == ('b', ) assert (actual.data == np.einsum('ij->j ', a)).all() - with pytest.raises(TypeError): - actual = xr.dot(da_a) with pytest.raises(TypeError): actual = xr.dot(da_a, dims='a', invalid=None) + with pytest.raises(TypeError): + actual = xr.dot(da_a.to_dataset(name='da'), dims='a') def test_where(): From 789cb9691042ba150e8b8f1c4b2a45f666099bde Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 8 Mar 2018 11:39:15 +0900 Subject: [PATCH 08/11] Simplify logic. More comments. --- xarray/core/computation.py | 49 +++++++++++++++++--------------- xarray/tests/test_computation.py | 7 +++++ 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f8854fad857..d1e5336f72b 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -968,15 +968,19 @@ def dot(*arrays, **kwargs): list(kwargs.keys()))) if any(not isinstance(arr, DataArray) for arr in arrays): - raise TypeError('Only xr.DataArray and xr.Variable are supported.') + raise TypeError('Only xr.DataArray and xr.Variable are supported.' + 'Given {}.'.format([type(arr) for arr in arrays])) + + if len(arrays) == 0: + raise TypeError('At least one array should be given.') if isinstance(dims, basestring): - dims = [dims] + dims = (dims, ) + elif isinstance(dims, list): + dims = tuple(dims) - common_dims = set(arrays[0].dims) + common_dims = set.intersection(*[set(arr.dims) for arr in arrays]) all_dims = [] - for arr in arrays[1:]: - common_dims = common_dims.intersection(set(arr.dims)) for arr in arrays: all_dims += [d for d in arr.dims if d not in all_dims] @@ -988,24 +992,13 @@ def dot(*arrays, **kwargs): dim_counts = Counter() for arr in arrays: dim_counts.update(arr.dims) - dims = [d for d, c in dim_counts.items() if c > 1] - - broadcast_dims = [d for d in common_dims if d not in dims] - input_core_dims = [] - output_core_dims = [[]] - for arr in arrays: - input_core_dims.append([d for d in arr.dims if d not in - broadcast_dims]) - output_core_dims[0] += [d for d in arr.dims if d not in - output_core_dims[0] + dims + broadcast_dims] + dims = tuple(d for d, c in dim_counts.items() if c > 1) - subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds - in input_core_dims] - subscripts = ','.join(subscripts_list) - subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) - - # dtype estimation is necessary for dask='parallelized' - out_dtype = dtypes.result_type(*arrays) + # dimensions to be parallelized + broadcast_dims = tuple(common_dims.difference(set(dims))) + input_core_dims = [[d for d in arr.dims if d not in broadcast_dims] + for arr in arrays] + output_core_dims = [set(all_dims).difference(set(dims + broadcast_dims))] # we use tensordot if possible, because it is more efficient for dask if len(broadcast_dims) == 0 and len(arrays) == 2: @@ -1016,7 +1009,17 @@ def dot(*arrays, **kwargs): output_core_dims=output_core_dims, kwargs={'axes': axes}) - # subscripts should be passed as arg, not as a kwargs. We need + # construct einsum subscripts, such as '...abc,...ab->...c' + # Note: input_core_dims are always moved to the last position + subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds + in input_core_dims] + subscripts = ','.join(subscripts_list) + subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) + + # dtype estimation is necessary for dask='parallelized' + out_dtype = dtypes.result_type(*arrays) + + # subscripts should be passed to np.einsum as arg, not as kwargs. We need # to construct a partial function for parallelized computation. func = functools.partial(np.einsum, subscripts) result = apply_ufunc(func, *arrays, diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index abb37ff4b1c..4180f8edf59 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -809,6 +809,11 @@ def test_dot(dask): assert actual.dims == ('c', 'e') assert (actual.data == np.einsum('ij,ijk,kl->kl ', a, b, c)).all() + # should work with tuple + actual = xr.dot(da_a, da_b, dims=('c', )) + assert actual.dims == ('a', 'b') + assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all() + # default dims actual = xr.dot(da_a, da_b, da_c) assert actual.dims == ('e', ) @@ -823,6 +828,8 @@ def test_dot(dask): actual = xr.dot(da_a, dims='a', invalid=None) with pytest.raises(TypeError): actual = xr.dot(da_a.to_dataset(name='da'), dims='a') + with pytest.raises(TypeError): + actual = xr.dot(dims='a') def test_where(): From a57907cf932c62967bf88fcad5f29c063612d5c3 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 8 Mar 2018 13:32:07 +0900 Subject: [PATCH 09/11] Support variable in xr.dot --- xarray/core/computation.py | 9 +++++---- xarray/tests/test_computation.py | 6 ++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index d1e5336f72b..22bc84c24b1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -935,7 +935,7 @@ def dot(*arrays, **kwargs): Parameters ---------- - arrays: DataArray objects + arrays: DataArray (or Variable) objects Arrays to compute. dims: str or tuple of strings, optional Which dimensions to sum over. @@ -961,13 +961,14 @@ def dot(*arrays, **kwargs): ('a', 'd') """ from .dataarray import DataArray + from .variable import Variable dims = kwargs.pop('dims', None) if len(kwargs) > 0: raise TypeError('Invalid keyward arguments {} are given'.format( list(kwargs.keys()))) - if any(not isinstance(arr, DataArray) for arr in arrays): + if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): raise TypeError('Only xr.DataArray and xr.Variable are supported.' 'Given {}.'.format([type(arr) for arr in arrays])) @@ -976,8 +977,6 @@ def dot(*arrays, **kwargs): if isinstance(dims, basestring): dims = (dims, ) - elif isinstance(dims, list): - dims = tuple(dims) common_dims = set.intersection(*[set(arr.dims) for arr in arrays]) all_dims = [] @@ -994,6 +993,8 @@ def dot(*arrays, **kwargs): dim_counts.update(arr.dims) dims = tuple(d for d, c in dim_counts.items() if c > 1) + dims = tuple(dims) # make dims a tuple + # dimensions to be parallelized broadcast_dims = tuple(common_dims.difference(set(dims))) input_core_dims = [[d for d in arr.dims if d not in broadcast_dims] diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 4180f8edf59..88710e55091 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -777,6 +777,12 @@ def test_dot(dask): actual = xr.dot(da_a) assert da_a.identical(actual) + # test for variable + actual = xr.dot(da_a.variable, da_b.variable) + assert actual.dims == ('c', ) + assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + assert isinstance(actual.data, type(da_a.variable.data)) + if dask: da_a = da_a.chunk({'a': 3}) da_b = da_b.chunk({'a': 3}) From 693b242503f7fbdba6cfa865f3bed67b8a7be6ba Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 8 Mar 2018 14:05:09 +0900 Subject: [PATCH 10/11] bug fix due to the undefined order of set --- xarray/core/computation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 22bc84c24b1..2df2e3db32f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -996,10 +996,12 @@ def dot(*arrays, **kwargs): dims = tuple(dims) # make dims a tuple # dimensions to be parallelized - broadcast_dims = tuple(common_dims.difference(set(dims))) + broadcast_dims = tuple(d for d in all_dims + if d in common_dims and d not in dims) input_core_dims = [[d for d in arr.dims if d not in broadcast_dims] for arr in arrays] - output_core_dims = [set(all_dims).difference(set(dims + broadcast_dims))] + output_core_dims = [tuple(d for d in all_dims if d not in + set(dims + broadcast_dims))] # we use tensordot if possible, because it is more efficient for dask if len(broadcast_dims) == 0 and len(arrays) == 2: From 88be319bf01c4803478e611373aa89b6631db096 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 8 Mar 2018 14:07:05 +0900 Subject: [PATCH 11/11] Remove unused casting to set --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 2df2e3db32f..71e591befa0 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1001,7 +1001,7 @@ def dot(*arrays, **kwargs): input_core_dims = [[d for d in arr.dims if d not in broadcast_dims] for arr in arrays] output_core_dims = [tuple(d for d in all_dims if d not in - set(dims + broadcast_dims))] + dims + broadcast_dims)] # we use tensordot if possible, because it is more efficient for dask if len(broadcast_dims) == 0 and len(arrays) == 2: