Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start renaming dims to dim #8487

Merged
merged 5 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ Breaking changes
Deprecations
~~~~~~~~~~~~

- As part of an effort to standardize the API, we're renaming the ``dims``
keyword arg to ``dim`` for the minority of functions which currently use
``dims``. This started with :py:func:`xarray.dot` & :py:meth:`DataArray.dot`
and we'll gradually roll this out across all functions. The warnings are
currently ``PendingDeprecationWarning``, which are silenced by default. We'll
convert these to ``DeprecationWarning`` in a future release.
By `Maximilian Roos <https://github.com/max-sixty>`_.

Bug fixes
~~~~~~~~~
Expand Down
14 changes: 7 additions & 7 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def assert_no_index_conflict(self) -> None:
"- they may be used to reindex data along common dimensions"
)

def _need_reindex(self, dims, cmp_indexes) -> bool:
def _need_reindex(self, dim, cmp_indexes) -> bool:
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
"""Whether or not we need to reindex variables for a set of
matching indexes.

Expand All @@ -340,14 +340,14 @@ def _need_reindex(self, dims, cmp_indexes) -> bool:
return True

unindexed_dims_sizes = {}
for dim in dims:
if dim in self.unindexed_dim_sizes:
sizes = self.unindexed_dim_sizes[dim]
for d in dim:
if d in self.unindexed_dim_sizes:
sizes = self.unindexed_dim_sizes[d]
if len(sizes) > 1:
# reindex if different sizes are found for unindexed dims
return True
else:
unindexed_dims_sizes[dim] = next(iter(sizes))
unindexed_dims_sizes[d] = next(iter(sizes))

if unindexed_dims_sizes:
indexed_dims_sizes = {}
Expand All @@ -356,8 +356,8 @@ def _need_reindex(self, dims, cmp_indexes) -> bool:
for var in index_vars.values():
indexed_dims_sizes.update(var.sizes)

for dim, size in unindexed_dims_sizes.items():
if indexed_dims_sizes.get(dim, -1) != size:
for d, size in unindexed_dims_sizes.items():
if indexed_dims_sizes.get(d, -1) != size:
# reindex if unindexed dimension size doesn't match
return True

Expand Down
28 changes: 15 additions & 13 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar
from xarray.core.variable import Variable
from xarray.util.deprecation_helpers import deprecate_dims

if TYPE_CHECKING:
from xarray.core.coordinates import Coordinates
Expand Down Expand Up @@ -1691,9 +1692,10 @@ def cross(
return c


@deprecate_dims
def dot(
*arrays,
dims: Dims = None,
dim: Dims = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like ``np.einsum``, but
Expand All @@ -1703,7 +1705,7 @@ def dot(
----------
*arrays : DataArray or Variable
Arrays to compute.
dims : str, iterable of hashable, "..." or None, optional
dim : str, iterable of hashable, "..." or None, optional
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
**kwargs : dict
Expand Down Expand Up @@ -1756,18 +1758,18 @@ def dot(
[3, 4, 5]])
Dimensions without coordinates: c, d

>>> xr.dot(da_a, da_b, dims=["a", "b"])
>>> xr.dot(da_a, da_b, dim=["a", "b"])
<xarray.DataArray (c: 2)>
array([110, 125])
Dimensions without coordinates: c

>>> xr.dot(da_a, da_b, dims=["a"])
>>> xr.dot(da_a, da_b, dim=["a"])
<xarray.DataArray (b: 2, c: 2)>
array([[40, 46],
[70, 79]])
Dimensions without coordinates: b, c

>>> xr.dot(da_a, da_b, da_c, dims=["b", "c"])
>>> xr.dot(da_a, da_b, da_c, dim=["b", "c"])
<xarray.DataArray (a: 3, d: 3)>
array([[ 9, 14, 19],
[ 93, 150, 207],
Expand All @@ -1779,7 +1781,7 @@ def dot(
array([110, 125])
Dimensions without coordinates: c

>>> xr.dot(da_a, da_b, dims=...)
>>> xr.dot(da_a, da_b, dim=...)
<xarray.DataArray ()>
array(235)
"""
Expand All @@ -1803,18 +1805,18 @@ def dot(
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dims is ...:
dims = all_dims
elif isinstance(dims, str):
dims = (dims,)
elif dims is None:
if dim is ...:
dim = all_dims
elif isinstance(dim, str):
dim = (dim,)
elif dim is None:
# find dimensions that occur more than one times
dim_counts: Counter = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
dims = tuple(d for d, c in dim_counts.items() if c > 1)
dim = tuple(d for d, c in dim_counts.items() if c > 1)

dot_dims: set[Hashable] = set(dims)
dot_dims: set[Hashable] = set(dim)

# dimensions to be parallelized
broadcast_dims = common_dims - dot_dims
Expand Down
17 changes: 9 additions & 8 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
from xarray.plot.accessor import DataArrayPlotAccessor
from xarray.plot.utils import _get_units_from_attrs
from xarray.util.deprecation_helpers import _deprecate_positional_args
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

if TYPE_CHECKING:
from typing import TypeVar, Union
Expand Down Expand Up @@ -115,14 +115,14 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset])


def _check_coords_dims(shape, coords, dims):
sizes = dict(zip(dims, shape))
def _check_coords_dims(shape, coords, dim):
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
sizes = dict(zip(dim, shape))
for k, v in coords.items():
if any(d not in dims for d in v.dims):
if any(d not in dim for d in v.dims):
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dims}"
f"dimensions {dim}"
)

for d, s in v.sizes.items():
Expand Down Expand Up @@ -4895,10 +4895,11 @@ def imag(self) -> Self:
"""
return self._replace(self.variable.imag)

@deprecate_dims
def dot(
self,
other: T_Xarray,
dims: Dims = None,
dim: Dims = None,
) -> T_Xarray:
"""Perform dot product of two DataArrays along their shared dims.

Expand All @@ -4908,7 +4909,7 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
dims : ..., str, Iterable of Hashable or None, optional
dim : ..., str, Iterable of Hashable or None, optional
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
If not specified, then all the common dimensions are summed over.

Expand Down Expand Up @@ -4947,7 +4948,7 @@ def dot(
if not isinstance(other, DataArray):
raise TypeError("dot only operates on DataArrays.")

return computation.dot(self, other, dims=dims)
return computation.dot(self, other, dim=dim)

def sortby(
self,
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1541,15 +1541,15 @@ def stack(self, dimensions=None, **dimensions_kwargs):
result = result._stack_once(dims, new_dim)
return result

def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self:
def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self:
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
"""
Unstacks the variable without needing an index.

Unlike `_unstack_once`, this function requires the existing dimension to
contain the full product of the new dimensions.
"""
new_dim_names = tuple(dims.keys())
new_dim_sizes = tuple(dims.values())
new_dim_names = tuple(dim.keys())
new_dim_sizes = tuple(dim.values())

if old_dim not in self.dims:
raise ValueError(f"invalid existing dimension: {old_dim}")
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _reduce(

# `dot` does not broadcast arrays, so this avoids creating a large
# DataArray (if `weights` has additional dimensions)
return dot(da, weights, dims=dim)
return dot(da, weights, dim=dim)

def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
"""Calculate the sum of weights, accounting for missing values"""
Expand Down
40 changes: 20 additions & 20 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1936,7 +1936,7 @@ def test_dot(use_dask: bool) -> None:
da_a = da_a.chunk({"a": 3})
da_b = da_b.chunk({"a": 3})
da_c = da_c.chunk({"c": 3})
actual = xr.dot(da_a, da_b, dims=["a", "b"])
actual = xr.dot(da_a, da_b, dim=["a", "b"])
assert actual.dims == ("c",)
assert (actual.data == np.einsum("ij,ijk->k", a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))
Expand All @@ -1960,33 +1960,33 @@ def test_dot(use_dask: bool) -> None:
if use_dask:
da_a = da_a.chunk({"a": 3})
da_b = da_b.chunk({"a": 3})
actual = xr.dot(da_a, da_b, dims=["b"])
actual = xr.dot(da_a, da_b, dim=["b"])
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))

actual = xr.dot(da_a, da_b, dims=["b"])
actual = xr.dot(da_a, da_b, dim=["b"])
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()

actual = xr.dot(da_a, da_b, dims="b")
actual = xr.dot(da_a, da_b, dim="b")
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()

actual = xr.dot(da_a, da_b, dims="a")
actual = xr.dot(da_a, da_b, dim="a")
assert actual.dims == ("b", "c")
assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all()

actual = xr.dot(da_a, da_b, dims="c")
actual = xr.dot(da_a, da_b, dim="c")
assert actual.dims == ("a", "b")
assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=["a", "b"])
actual = xr.dot(da_a, da_b, da_c, dim=["a", "b"])
assert actual.dims == ("c", "e")
assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all()

# should work with tuple
actual = xr.dot(da_a, da_b, dims=("c",))
actual = xr.dot(da_a, da_b, dim=("c",))
assert actual.dims == ("a", "b")
assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all()

Expand All @@ -1996,47 +1996,47 @@ def test_dot(use_dask: bool) -> None:
assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all()

# 1 array summation
actual = xr.dot(da_a, dims="a")
actual = xr.dot(da_a, dim="a")
assert actual.dims == ("b",)
assert (actual.data == np.einsum("ij->j ", a)).all()

# empty dim
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims="a")
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim="a")
assert actual.dims == ("b",)
assert (actual.data == np.zeros(actual.shape)).all()

# Ellipsis (...) sums over all dimensions
actual = xr.dot(da_a, da_b, dims=...)
actual = xr.dot(da_a, da_b, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk->", a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=...)
actual = xr.dot(da_a, da_b, da_c, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all()

actual = xr.dot(da_a, dims=...)
actual = xr.dot(da_a, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij-> ", a)).all()

actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...)
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim=...)
assert actual.dims == ()
assert (actual.data == np.zeros(actual.shape)).all()

# Invalid cases
if not use_dask:
with pytest.raises(TypeError):
xr.dot(da_a, dims="a", invalid=None)
xr.dot(da_a, dim="a", invalid=None)
with pytest.raises(TypeError):
xr.dot(da_a.to_dataset(name="da"), dims="a")
xr.dot(da_a.to_dataset(name="da"), dim="a")
with pytest.raises(TypeError):
xr.dot(dims="a")
xr.dot(dim="a")

# einsum parameters
actual = xr.dot(da_a, da_b, dims=["b"], order="C")
actual = xr.dot(da_a, da_b, dim=["b"], order="C")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
assert actual.values.flags["C_CONTIGUOUS"]
assert not actual.values.flags["F_CONTIGUOUS"]
actual = xr.dot(da_a, da_b, dims=["b"], order="F")
actual = xr.dot(da_a, da_b, dim=["b"], order="F")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
# dask converts Fortran arrays to C order when merging the final array
if not use_dask:
Expand Down Expand Up @@ -2078,7 +2078,7 @@ def test_dot_align_coords(use_dask: bool) -> None:
expected = (da_a * da_b).sum(["a", "b"])
xr.testing.assert_allclose(expected, actual)

actual = xr.dot(da_a, da_b, dims=...)
actual = xr.dot(da_a, da_b, dim=...)
expected = (da_a * da_b).sum()
xr.testing.assert_allclose(expected, actual)

Expand Down
6 changes: 3 additions & 3 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3964,13 +3964,13 @@ def test_dot(self) -> None:
assert_equal(expected3, actual3)

# Ellipsis: all dims are shared
actual4 = da.dot(da, dims=...)
actual4 = da.dot(da, dim=...)
expected4 = da.dot(da)
assert_equal(expected4, actual4)

# Ellipsis: not all dims are shared
actual5 = da.dot(dm3, dims=...)
expected5 = da.dot(dm3, dims=("j", "x", "y", "z"))
actual5 = da.dot(dm3, dim=...)
expected5 = da.dot(dm3, dim=("j", "x", "y", "z"))
assert_equal(expected5, actual5)

with pytest.raises(NotImplementedError):
Expand Down
27 changes: 27 additions & 0 deletions xarray/util/deprecation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from functools import wraps
from typing import Callable, TypeVar

from xarray.core.utils import emit_user_level_warning

T = TypeVar("T", bound=Callable)

POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
Expand Down Expand Up @@ -115,3 +117,28 @@ def inner(*args, **kwargs):
return inner

return _decorator


def deprecate_dims(func: T) -> T:
    """
    For functions that previously took `dims` as a kwarg, and have now transitioned to
    `dim`. This decorator will issue a warning if `dims` is passed while forwarding it
    to `dim`.

    Parameters
    ----------
    func : callable
        Function or method whose keyword argument is now called ``dim``.

    Returns
    -------
    callable
        A wrapper around ``func`` (metadata preserved via ``functools.wraps``)
        that renames a ``dims`` keyword argument to ``dim`` before calling
        ``func``, emitting a ``PendingDeprecationWarning`` when it does so.

    Raises
    ------
    TypeError
        Raised by the wrapper at call time if both ``dim`` and ``dims`` are
        passed as keyword arguments, since it is ambiguous which one the
        caller intended.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if "dims" in kwargs:
            if "dim" in kwargs:
                # Forwarding `dims` here would silently discard the explicit
                # `dim` value; refuse the ambiguous call instead, mirroring the
                # TypeError Python raises for duplicate arguments.
                name = getattr(func, "__name__", "function")
                raise TypeError(
                    f"{name}() received both `dim` and the deprecated `dims` "
                    "argument; pass only `dim`."
                )
            emit_user_level_warning(
                "The `dims` argument has been renamed to `dim`, and will be removed "
                "in the future. This renaming is taking place throughout xarray over the "
                "next few releases.",
                # Upgrade to `DeprecationWarning` in the future, when the renaming is complete.
                PendingDeprecationWarning,
            )
            kwargs["dim"] = kwargs.pop("dims")
        return func(*args, **kwargs)

    # We're quite confident we're just returning `T` from this function, so it's fine to ignore typing
    # within the function.
    return wrapper  # type: ignore
Loading