diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f9e2e9270b0..25133df8a29 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,12 +21,15 @@ v0.15.0 (unreleased) Breaking changes ~~~~~~~~~~~~~~~~ - - Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which have been deprecated since 0.12. (:pull:`3650`). Instead, specify the encoding when writing to disk or set the ``encoding`` attribute directly. By `Maximilian Roos `_ +- :py:func:`xarray.dot`, :py:meth:`DataArray.dot`, and the ``@`` operator now + use ``align="inner"`` (except when ``xarray.set_options(arithmetic_join="exact")``; + :issue:`3694`) by `Mathias Hauser `_. + New Features ~~~~~~~~~~~~ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 34de5edefc5..eb9ca8c17fc 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -26,6 +26,7 @@ from . import duck_array_ops, utils from .alignment import deep_align from .merge import merge_coordinates_without_align +from .options import OPTIONS from .pycompat import dask_array_type from .utils import is_dict_like from .variable import Variable @@ -1175,6 +1176,11 @@ def dot(*arrays, dims=None, **kwargs): subscripts = ",".join(subscripts_list) subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]]) + join = OPTIONS["arithmetic_join"] + # using "inner" emulates `(a * b).sum()` for all joins (except "exact") + if join != "exact": + join = "inner" + # subscripts should be passed to np.einsum as arg, not as kwargs. We need # to construct a partial function for apply_ufunc to work. func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) @@ -1183,6 +1189,7 @@ def dot(*arrays, dims=None, **kwargs): *arrays, input_core_dims=input_core_dims, output_core_dims=output_core_dims, + join=join, dask="allowed", ) return result.transpose(*[d for d in all_dims if d in result.dims]) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 1f2634cc9b0..2d373d12095 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1043,6 +1043,60 @@ def test_dot(use_dask): pickle.loads(pickle.dumps(xr.dot(da_a))) +@pytest.mark.parametrize("use_dask", [True, False]) +def test_dot_align_coords(use_dask): + # GH 3694 + + if use_dask: + if not has_dask: + pytest.skip("test for dask.") + + a = np.arange(30 * 4).reshape(30, 4) + b = np.arange(30 * 4 * 5).reshape(30, 4, 5) + + # use partially overlapping coords + coords_a = {"a": np.arange(30), "b": np.arange(4)} + coords_b = {"a": np.arange(5, 35), "b": np.arange(1, 5)} + + da_a = xr.DataArray(a, dims=["a", "b"], coords=coords_a) + da_b = xr.DataArray(b, dims=["a", "b", "c"], coords=coords_b) + + if use_dask: + da_a = da_a.chunk({"a": 3}) + da_b = da_b.chunk({"a": 3}) + + # join="inner" is the default + actual = xr.dot(da_a, da_b) + # `dot` sums over the common dimensions of the arguments + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + actual = xr.dot(da_a, da_b, dims=...) + expected = (da_a * da_b).sum() + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + xr.dot(da_a, da_b) + + # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all + # join method (except "exact") + with xr.set_options(arithmetic_join="left"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="right"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="outer"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + def test_where(): cond = xr.DataArray([True, False], dims="x") actual = xr.where(cond, 1, 0) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 786eb5007a6..962be7548bc 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3973,6 +3973,43 @@ def test_dot(self): with pytest.raises(TypeError): da.dot(dm.values) + def test_dot_align_coords(self): + # GH 3694 + + x = np.linspace(-3, 3, 6) + y = np.linspace(-3, 3, 5) + z_a = range(4) + da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) + da = DataArray(da_vals, coords=[x, y, z_a], dims=["x", "y", "z"]) + + z_m = range(2, 6) + dm_vals = range(4) + dm = DataArray(dm_vals, coords=[z_m], dims=["z"]) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + da.dot(dm) + + da_aligned, dm_aligned = xr.align(da, dm, join="inner") + + # nd dot 1d + actual = da.dot(dm) + expected_vals = np.tensordot(da_aligned.values, dm_aligned.values, [2, 0]) + expected = DataArray(expected_vals, coords=[x, da_aligned.y], dims=["x", "y"]) + assert_equal(expected, actual) + + # multiple shared dims + dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4)) + j = np.linspace(-3, 3, 20) + dm = DataArray(dm_vals, coords=[j, y, z_m], dims=["j", "y", "z"]) + da_aligned, dm_aligned = xr.align(da, dm, join="inner") + actual = da.dot(dm) + expected_vals = np.tensordot( + da_aligned.values, dm_aligned.values, axes=([1, 2], [1, 2]) + ) + expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"]) + assert_equal(expected, actual) + def test_matmul(self): # copied from above (could make a fixture) @@ -3986,6 +4023,24 @@ def test_matmul(self): expected = da.dot(da) assert_identical(result, expected) + def test_matmul_align_coords(self): + # GH 3694 + + x_a = np.arange(6) + x_b = np.arange(2, 8) + da_vals = np.arange(6) + da_a = DataArray(da_vals, coords=[x_a], dims=["x"]) + da_b = DataArray(da_vals, coords=[x_b], dims=["x"]) + + # only test arithmetic_join="inner" (=default) + result = da_a @ da_b + expected = da_a.dot(da_b) + assert_identical(result, expected) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + da_a @ da_b + def test_binary_op_propagate_indexes(self): # regression test for GH2227 self.dv["x"] = np.arange(self.dv.sizes["x"])