Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start renaming dims to dim #8487

Merged
merged 5 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ Breaking changes
Deprecations
~~~~~~~~~~~~

- As part of an effort to standardize the API, we're renaming the ``dims``
keyword arg to ``dim`` for the minority of functions which currently use
``dims``. This started with :py:func:`xarray.dot` & :py:meth:`DataArray.dot`
and we'll gradually roll this out across all functions. The warnings are
currently ``PendingDeprecationWarning``\ s, which are silenced by default. We'll
convert these to ``DeprecationWarning``\ s in a future release.
By `Maximilian Roos <https://github.com/max-sixty>`_.

Bug fixes
~~~~~~~~~
Expand Down
16 changes: 9 additions & 7 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from xarray.core.types import T_Alignable
from xarray.core.utils import is_dict_like, is_full_slice
from xarray.core.variable import Variable, as_compatible_data, calculate_dimensions
from xarray.util.deprecation_helpers import deprecate_dims

if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
Expand Down Expand Up @@ -324,7 +325,8 @@ def assert_no_index_conflict(self) -> None:
"- they may be used to reindex data along common dimensions"
)

def _need_reindex(self, dims, cmp_indexes) -> bool:
@deprecate_dims
def _need_reindex(self, dim, cmp_indexes) -> bool:
"""Whether or not we need to reindex variables for a set of
matching indexes.

Expand All @@ -340,14 +342,14 @@ def _need_reindex(self, dims, cmp_indexes) -> bool:
return True

unindexed_dims_sizes = {}
for dim in dims:
if dim in self.unindexed_dim_sizes:
sizes = self.unindexed_dim_sizes[dim]
for d in dim:
if d in self.unindexed_dim_sizes:
sizes = self.unindexed_dim_sizes[d]
if len(sizes) > 1:
# reindex if different sizes are found for unindexed dims
return True
else:
unindexed_dims_sizes[dim] = next(iter(sizes))
unindexed_dims_sizes[d] = next(iter(sizes))

if unindexed_dims_sizes:
indexed_dims_sizes = {}
Expand All @@ -356,8 +358,8 @@ def _need_reindex(self, dims, cmp_indexes) -> bool:
for var in index_vars.values():
indexed_dims_sizes.update(var.sizes)

for dim, size in unindexed_dims_sizes.items():
if indexed_dims_sizes.get(dim, -1) != size:
for d, size in unindexed_dims_sizes.items():
if indexed_dims_sizes.get(d, -1) != size:
# reindex if unindexed dimension size doesn't match
return True

Expand Down
28 changes: 15 additions & 13 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar
from xarray.core.variable import Variable
from xarray.util.deprecation_helpers import deprecate_dims

if TYPE_CHECKING:
from xarray.core.coordinates import Coordinates
Expand Down Expand Up @@ -1691,9 +1692,10 @@ def cross(
return c


@deprecate_dims
def dot(
*arrays,
dims: Dims = None,
dim: Dims = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like ``np.einsum``, but
Expand All @@ -1703,7 +1705,7 @@ def dot(
----------
*arrays : DataArray or Variable
Arrays to compute.
dims : str, iterable of hashable, "..." or None, optional
dim : str, iterable of hashable, "..." or None, optional
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
**kwargs : dict
Expand Down Expand Up @@ -1756,18 +1758,18 @@ def dot(
[3, 4, 5]])
Dimensions without coordinates: c, d

>>> xr.dot(da_a, da_b, dims=["a", "b"])
>>> xr.dot(da_a, da_b, dim=["a", "b"])
<xarray.DataArray (c: 2)>
array([110, 125])
Dimensions without coordinates: c

>>> xr.dot(da_a, da_b, dims=["a"])
>>> xr.dot(da_a, da_b, dim=["a"])
<xarray.DataArray (b: 2, c: 2)>
array([[40, 46],
[70, 79]])
Dimensions without coordinates: b, c

>>> xr.dot(da_a, da_b, da_c, dims=["b", "c"])
>>> xr.dot(da_a, da_b, da_c, dim=["b", "c"])
<xarray.DataArray (a: 3, d: 3)>
array([[ 9, 14, 19],
[ 93, 150, 207],
Expand All @@ -1779,7 +1781,7 @@ def dot(
array([110, 125])
Dimensions without coordinates: c

>>> xr.dot(da_a, da_b, dims=...)
>>> xr.dot(da_a, da_b, dim=...)
<xarray.DataArray ()>
array(235)
"""
Expand All @@ -1803,18 +1805,18 @@ def dot(
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dims is ...:
dims = all_dims
elif isinstance(dims, str):
dims = (dims,)
elif dims is None:
if dim is ...:
dim = all_dims
elif isinstance(dim, str):
dim = (dim,)
elif dim is None:
# find dimensions that occur more than once
dim_counts: Counter = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
dims = tuple(d for d, c in dim_counts.items() if c > 1)
dim = tuple(d for d, c in dim_counts.items() if c > 1)

dot_dims: set[Hashable] = set(dims)
dot_dims: set[Hashable] = set(dim)

# dimensions to be parallelized
broadcast_dims = common_dims - dot_dims
Expand Down
18 changes: 10 additions & 8 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
from xarray.plot.accessor import DataArrayPlotAccessor
from xarray.plot.utils import _get_units_from_attrs
from xarray.util.deprecation_helpers import _deprecate_positional_args
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

if TYPE_CHECKING:
from typing import TypeVar, Union
Expand Down Expand Up @@ -115,14 +115,15 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset])


def _check_coords_dims(shape, coords, dims):
sizes = dict(zip(dims, shape))
@deprecate_dims
def _check_coords_dims(shape, coords, dim):
sizes = dict(zip(dim, shape))
for k, v in coords.items():
if any(d not in dims for d in v.dims):
if any(d not in dim for d in v.dims):
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dims}"
f"dimensions {dim}"
)

for d, s in v.sizes.items():
Expand Down Expand Up @@ -4895,10 +4896,11 @@ def imag(self) -> Self:
"""
return self._replace(self.variable.imag)

@deprecate_dims
def dot(
self,
other: T_Xarray,
dims: Dims = None,
dim: Dims = None,
) -> T_Xarray:
"""Perform dot product of two DataArrays along their shared dims.

Expand All @@ -4908,7 +4910,7 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
dims : ..., str, Iterable of Hashable or None, optional
dim : ..., str, Iterable of Hashable or None, optional
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
If not specified, then all the common dimensions are summed over.

Expand Down Expand Up @@ -4947,7 +4949,7 @@ def dot(
if not isinstance(other, DataArray):
raise TypeError("dot only operates on DataArrays.")

return computation.dot(self, other, dims=dims)
return computation.dot(self, other, dim=dim)

def sortby(
self,
Expand Down
8 changes: 5 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
maybe_coerce_to_str,
)
from xarray.namedarray.core import NamedArray
from xarray.util.deprecation_helpers import deprecate_dims

NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
indexing.ExplicitlyIndexed,
Expand Down Expand Up @@ -1541,15 +1542,16 @@ def stack(self, dimensions=None, **dimensions_kwargs):
result = result._stack_once(dims, new_dim)
return result

def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self:
@deprecate_dims
def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self:
"""
Unstacks the variable without needing an index.

Unlike `_unstack_once`, this function requires the existing dimension to
contain the full product of the new dimensions.
"""
new_dim_names = tuple(dims.keys())
new_dim_sizes = tuple(dims.values())
new_dim_names = tuple(dim.keys())
new_dim_sizes = tuple(dim.values())

if old_dim not in self.dims:
raise ValueError(f"invalid existing dimension: {old_dim}")
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _reduce(

# `dot` does not broadcast arrays, so this avoids creating a large
# DataArray (if `weights` has additional dimensions)
return dot(da, weights, dims=dim)
return dot(da, weights, dim=dim)

def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
"""Calculate the sum of weights, accounting for missing values"""
Expand Down
40 changes: 20 additions & 20 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1936,7 +1936,7 @@ def test_dot(use_dask: bool) -> None:
da_a = da_a.chunk({"a": 3})
da_b = da_b.chunk({"a": 3})
da_c = da_c.chunk({"c": 3})
actual = xr.dot(da_a, da_b, dims=["a", "b"])
actual = xr.dot(da_a, da_b, dim=["a", "b"])
assert actual.dims == ("c",)
assert (actual.data == np.einsum("ij,ijk->k", a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))
Expand All @@ -1960,33 +1960,33 @@ def test_dot(use_dask: bool) -> None:
if use_dask:
da_a = da_a.chunk({"a": 3})
da_b = da_b.chunk({"a": 3})
actual = xr.dot(da_a, da_b, dims=["b"])
actual = xr.dot(da_a, da_b, dim=["b"])
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))

actual = xr.dot(da_a, da_b, dims=["b"])
actual = xr.dot(da_a, da_b, dim=["b"])
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()

actual = xr.dot(da_a, da_b, dims="b")
actual = xr.dot(da_a, da_b, dim="b")
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()

actual = xr.dot(da_a, da_b, dims="a")
actual = xr.dot(da_a, da_b, dim="a")
assert actual.dims == ("b", "c")
assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all()

actual = xr.dot(da_a, da_b, dims="c")
actual = xr.dot(da_a, da_b, dim="c")
assert actual.dims == ("a", "b")
assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=["a", "b"])
actual = xr.dot(da_a, da_b, da_c, dim=["a", "b"])
assert actual.dims == ("c", "e")
assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all()

# should work with tuple
actual = xr.dot(da_a, da_b, dims=("c",))
actual = xr.dot(da_a, da_b, dim=("c",))
assert actual.dims == ("a", "b")
assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all()

Expand All @@ -1996,47 +1996,47 @@ def test_dot(use_dask: bool) -> None:
assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all()

# 1 array summation
actual = xr.dot(da_a, dims="a")
actual = xr.dot(da_a, dim="a")
assert actual.dims == ("b",)
assert (actual.data == np.einsum("ij->j ", a)).all()

# empty dim
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims="a")
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim="a")
assert actual.dims == ("b",)
assert (actual.data == np.zeros(actual.shape)).all()

# Ellipsis (...) sums over all dimensions
actual = xr.dot(da_a, da_b, dims=...)
actual = xr.dot(da_a, da_b, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk->", a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=...)
actual = xr.dot(da_a, da_b, da_c, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all()

actual = xr.dot(da_a, dims=...)
actual = xr.dot(da_a, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij-> ", a)).all()

actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...)
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim=...)
assert actual.dims == ()
assert (actual.data == np.zeros(actual.shape)).all()

# Invalid cases
if not use_dask:
with pytest.raises(TypeError):
xr.dot(da_a, dims="a", invalid=None)
xr.dot(da_a, dim="a", invalid=None)
with pytest.raises(TypeError):
xr.dot(da_a.to_dataset(name="da"), dims="a")
xr.dot(da_a.to_dataset(name="da"), dim="a")
with pytest.raises(TypeError):
xr.dot(dims="a")
xr.dot(dim="a")

# einsum parameters
actual = xr.dot(da_a, da_b, dims=["b"], order="C")
actual = xr.dot(da_a, da_b, dim=["b"], order="C")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
assert actual.values.flags["C_CONTIGUOUS"]
assert not actual.values.flags["F_CONTIGUOUS"]
actual = xr.dot(da_a, da_b, dims=["b"], order="F")
actual = xr.dot(da_a, da_b, dim=["b"], order="F")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
# dask converts Fortran arrays to C order when merging the final array
if not use_dask:
Expand Down Expand Up @@ -2078,7 +2078,7 @@ def test_dot_align_coords(use_dask: bool) -> None:
expected = (da_a * da_b).sum(["a", "b"])
xr.testing.assert_allclose(expected, actual)

actual = xr.dot(da_a, da_b, dims=...)
actual = xr.dot(da_a, da_b, dim=...)
expected = (da_a * da_b).sum()
xr.testing.assert_allclose(expected, actual)

Expand Down
6 changes: 3 additions & 3 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3964,13 +3964,13 @@ def test_dot(self) -> None:
assert_equal(expected3, actual3)

# Ellipsis: all dims are shared
actual4 = da.dot(da, dims=...)
actual4 = da.dot(da, dim=...)
expected4 = da.dot(da)
assert_equal(expected4, actual4)

# Ellipsis: not all dims are shared
actual5 = da.dot(dm3, dims=...)
expected5 = da.dot(dm3, dims=("j", "x", "y", "z"))
actual5 = da.dot(dm3, dim=...)
expected5 = da.dot(dm3, dim=("j", "x", "y", "z"))
assert_equal(expected5, actual5)

with pytest.raises(NotImplementedError):
Expand Down
Loading