Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/align in dot #3699

Merged
merged 7 commits into from
Jan 20, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ v0.15.0 (unreleased)

Breaking changes
~~~~~~~~~~~~~~~~

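- :py:func:`xarray.dot`, :py:meth:`DataArray.dot`, and the ``@`` operator now
  align their arguments using ``join="inner"`` instead of raising an error when
  the indexes do not match. The only exception is
  ``xarray.set_options(arithmetic_join="exact")``, for which mismatched indexes
  still raise (:issue:`3694`).
  By `Mathias Hauser <https://github.com/mathause>`_.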
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It stops raising an error when your data does not align, no change if it does; so I assume it is unlikely for anyone to rely on this raising?


New Features
~~~~~~~~~~~~
Expand Down
7 changes: 7 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import duck_array_ops, utils
from .alignment import deep_align
from .merge import merge_coordinates_without_align
from .options import OPTIONS
from .pycompat import dask_array_type
from .utils import is_dict_like
from .variable import Variable
Expand Down Expand Up @@ -1175,6 +1176,11 @@ def dot(*arrays, dims=None, **kwargs):
subscripts = ",".join(subscripts_list)
subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]])

join = OPTIONS["arithmetic_join"]
# using "inner" emulates `(a * b).sum()` for all joins (except "exact")
if join != "exact":
join = "inner"

# subscripts should be passed to np.einsum as arg, not as kwargs. We need
# to construct a partial function for apply_ufunc to work.
func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs)
Expand All @@ -1183,6 +1189,7 @@ def dot(*arrays, dims=None, **kwargs):
*arrays,
input_core_dims=input_core_dims,
output_core_dims=output_core_dims,
join=join,
dask="allowed",
)
return result.transpose(*[d for d in all_dims if d in result.dims])
Expand Down
54 changes: 54 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,60 @@ def test_dot(use_dask):
pickle.loads(pickle.dumps(xr.dot(da_a)))


@pytest.mark.parametrize("use_dask", [True, False])
def test_dot_align_coords(use_dask):
    # GH 3694: `xr.dot` must align its arguments before contracting them

    if use_dask and not has_dask:
        pytest.skip("test for dask.")

    values_2d = np.arange(30 * 4).reshape(30, 4)
    values_3d = np.arange(30 * 4 * 5).reshape(30, 4, 5)

    # the coordinates of the two arrays overlap only partially
    da_a = xr.DataArray(
        values_2d,
        dims=["a", "b"],
        coords={"a": np.arange(30), "b": np.arange(4)},
    )
    da_b = xr.DataArray(
        values_3d,
        dims=["a", "b", "c"],
        coords={"a": np.arange(5, 35), "b": np.arange(1, 5)},
    )

    if use_dask:
        da_a = da_a.chunk({"a": 3})
        da_b = da_b.chunk({"a": 3})

    # dimensions shared by both arguments; `dot` sums over these by default
    shared_dims = ["a", "b"]

    # default options, i.e. arithmetic_join="inner"
    actual = xr.dot(da_a, da_b)
    expected = (da_a * da_b).sum(shared_dims)
    xr.testing.assert_allclose(expected, actual)

    # summing over every dimension
    actual = xr.dot(da_a, da_b, dims=...)
    expected = (da_a * da_b).sum()
    xr.testing.assert_allclose(expected, actual)

    # "exact" is the one join that must still refuse mismatching indexes
    with xr.set_options(arithmetic_join="exact"), raises_regex(
        ValueError, "indexes along dimension"
    ):
        xr.dot(da_a, da_b)

    # NOTE: `dot` uses join="inner" for every remaining join method, because
    # `(a * b).sum()` gives the same result for all of them
    for join in ("left", "right", "outer"):
        with xr.set_options(arithmetic_join=join):
            actual = xr.dot(da_a, da_b)
            expected = (da_a * da_b).sum(shared_dims)
            xr.testing.assert_allclose(expected, actual)


def test_where():
cond = xr.DataArray([True, False], dims="x")
actual = xr.where(cond, 1, 0)
Expand Down
55 changes: 55 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3973,6 +3973,43 @@ def test_dot(self):
with pytest.raises(TypeError):
da.dot(dm.values)

def test_dot_align_coords(self):
    # GH 3694: `DataArray.dot` must align its operands before contracting

    x = np.linspace(-3, 3, 6)
    y = np.linspace(-3, 3, 5)
    z_left = range(4)
    # only partially overlaps with `z_left`
    z_right = range(2, 6)

    da = DataArray(
        np.arange(6 * 5 * 4).reshape((6, 5, 4)),
        coords=[x, y, z_left],
        dims=["x", "y", "z"],
    )
    dm = DataArray(range(4), coords=[z_right], dims=["z"])

    # with an "exact" join the mismatching "z" index has to raise
    with xr.set_options(arithmetic_join="exact"), raises_regex(
        ValueError, "indexes along dimension"
    ):
        da.dot(dm)

    # nd dot 1d: compare against numpy applied to the inner-aligned operands
    da_inner, dm_inner = xr.align(da, dm, join="inner")
    actual = da.dot(dm)
    expected = DataArray(
        np.tensordot(da_inner.values, dm_inner.values, [2, 0]),
        coords=[x, da_inner.y],
        dims=["x", "y"],
    )
    assert_equal(expected, actual)

    # multiple shared dims ("y" and "z")
    j = np.linspace(-3, 3, 20)
    dm = DataArray(
        np.arange(20 * 5 * 4).reshape((20, 5, 4)),
        coords=[j, y, z_right],
        dims=["j", "y", "z"],
    )
    da_inner, dm_inner = xr.align(da, dm, join="inner")
    actual = da.dot(dm)
    expected = DataArray(
        np.tensordot(da_inner.values, dm_inner.values, axes=([1, 2], [1, 2])),
        coords=[x, j],
        dims=["x", "j"],
    )
    assert_equal(expected, actual)

def test_matmul(self):

# copied from above (could make a fixture)
Expand All @@ -3986,6 +4023,24 @@ def test_matmul(self):
expected = da.dot(da)
assert_identical(result, expected)

def test_matmul_align_coords(self):
    # GH 3694: the `@` operator has to align coordinates just like `dot`

    values = np.arange(6)
    # the two "x" coordinates overlap only partially
    da_a = DataArray(values, coords=[np.arange(6)], dims=["x"])
    da_b = DataArray(values, coords=[np.arange(2, 8)], dims=["x"])

    # only the default (arithmetic_join="inner") is compared against `dot`
    assert_identical(da_a @ da_b, da_a.dot(da_b))

    # an "exact" join must reject the mismatching index
    with xr.set_options(arithmetic_join="exact"), raises_regex(
        ValueError, "indexes along dimension"
    ):
        da_a @ da_b

def test_binary_op_propagate_indexes(self):
# regression test for GH2227
self.dv["x"] = np.arange(self.dv.sizes["x"])
Expand Down