-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
einsum for xarray #1968
einsum for xarray #1968
Changes from 8 commits
220ebcc
4239ac6
0f472a2
c83d442
1c732a4
b8d93b0
3278bf3
1ec5683
789cb96
a57907c
693b242
88be319
b3d4768
2bd06ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ Top-level functions | |
full_like | ||
zeros_like | ||
ones_like | ||
dot | ||
|
||
Dataset | ||
======= | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,14 @@ | |
import functools | ||
import itertools | ||
import operator | ||
from collections import Counter | ||
|
||
import numpy as np | ||
|
||
from . import duck_array_ops, utils | ||
from . import duck_array_ops, utils, dtypes | ||
from .alignment import deep_align | ||
from .merge import expand_and_merge_variables | ||
from .pycompat import OrderedDict, dask_array_type | ||
from .pycompat import OrderedDict, dask_array_type, basestring | ||
from .utils import is_dict_like | ||
|
||
_DEFAULT_FROZEN_SET = frozenset() | ||
|
@@ -926,6 +927,105 @@ def earth_mover_distance(first_samples, | |
return apply_array_ufunc(func, *args, dask=dask) | ||
|
||
|
||
def dot(*arrays, **kwargs): | ||
""" dot(*arrays, dims=None) | ||
|
||
Generalized dot product for xarray objects. Like np.einsum, but | ||
provides a simpler interface based on array dimensions. | ||
|
||
Parameters | ||
---------- | ||
arrays: DataArray objects | ||
Arrays to compute. | ||
dims: str or tuple of strings, optional | ||
Which dimensions to sum over. | ||
If not specified, then all the common dimensions are summed over. | ||
|
||
Returns | ||
------- | ||
dot: DataArray | ||
|
||
Examples | ||
-------- | ||
|
||
>>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b']) | ||
>>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5), | ||
>>> dims=['a', 'b', 'c']) | ||
>>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd']) | ||
>>> | ||
>>> xr.dot(da_a, da_b, dims=['a', 'b']).dims | ||
('c', ) | ||
>>> xr.dot(da_a, da_b, dims=['a']).dims | ||
('b', 'c') | ||
>>> xr.dot(da_a, da_b, da_c, dims=['b', 'c']).dims | ||
('a', 'd') | ||
""" | ||
from .dataarray import DataArray | ||
|
||
dims = kwargs.pop('dims', None) | ||
if len(kwargs) > 0: | ||
raise TypeError('Invalid keyward arguments {} are given'.format( | ||
list(kwargs.keys()))) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if you write |
||
if any(not isinstance(arr, DataArray) for arr in arrays): | ||
raise TypeError('Only xr.DataArray and xr.Variable are supported.') | ||
|
||
if isinstance(dims, basestring): | ||
dims = [dims] | ||
|
||
common_dims = set(arrays[0].dims) | ||
all_dims = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would it work to make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want to keep the occurrence order in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, sounds good. |
||
for arr in arrays[1:]: | ||
common_dims = common_dims.intersection(set(arr.dims)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a slightly different choice of default dimensions than
Should we switch this behavior to match There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be slightly more efficient to construct e.g., |
||
for arr in arrays: | ||
all_dims += [d for d in arr.dims if d not in all_dims] | ||
|
||
einsum_axes = 'abcdefghijklmnopqrstuvwxyz' | ||
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} | ||
|
||
if dims is None: | ||
# find dimensions that occur more than once | ||
dim_counts = Counter() | ||
for arr in arrays: | ||
dim_counts.update(arr.dims) | ||
dims = [d for d, c in dim_counts.items() if c > 1] | ||
|
||
broadcast_dims = [d for d in common_dims if d not in dims] | ||
input_core_dims = [] | ||
output_core_dims = [[]] | ||
for arr in arrays: | ||
input_core_dims.append([d for d in arr.dims if d not in | ||
broadcast_dims]) | ||
output_core_dims[0] += [d for d in arr.dims if d not in | ||
output_core_dims[0] + dims + broadcast_dims] | ||
|
||
subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds | ||
in input_core_dims] | ||
subscripts = ','.join(subscripts_list) | ||
subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) | ||
|
||
# dtype estimation is necessary for dask='parallelized' | ||
out_dtype = dtypes.result_type(*arrays) | ||
|
||
# we use tensordot if possible, because it is more efficient for dask | ||
if len(broadcast_dims) == 0 and len(arrays) == 2: | ||
axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims] | ||
for arr in arrays] | ||
return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed', | ||
input_core_dims=input_core_dims, | ||
output_core_dims=output_core_dims, | ||
kwargs={'axes': axes}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. I added a path for tensordot, which dask can compute more efficiently. |
||
|
||
# subscripts should be passed as arg, not as a kwargs. We need | ||
# to construct a partial function for parallelized computation. | ||
func = functools.partial(np.einsum, subscripts) | ||
result = apply_ufunc(func, *arrays, | ||
input_core_dims=input_core_dims, | ||
output_core_dims=output_core_dims, | ||
dask='parallelized', output_dtypes=[out_dtype]) | ||
return result.transpose(*[d for d in all_dims if d in result.dims]) | ||
|
||
|
||
def where(cond, x, y): | ||
"""Return elements from `x` or `y` depending on `cond`. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dot(*arrays, *, dims=None)
is the way to write this with Python 3's keyword-only arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should keep this as
dot(*arrays, **kwargs)
since we have not yet dropped Python 2 support?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was confused.
def dot(*arrays, *, dims=None)
is not valid syntax in Python 3, either. (There can only be one single `*`.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PEP 3102 says that Python 3 supports the form
def dot(*arrays, dim=None).