Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate ds.dims returning dict #8500

Merged
merged 26 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
19708fa
raise FutureWarning
TomNicholas Dec 1, 2023
9747fe7
change some internal instances of ds.dims -> ds.sizes
TomNicholas Dec 1, 2023
0362233
improve clarity of which unexpected errors were raised
TomNicholas Dec 1, 2023
99d7e16
whatsnew
TomNicholas Dec 1, 2023
3e05777
return a class which warns if treated like a Mapping
TomNicholas Dec 1, 2023
5a22dca
fix failing tests
TomNicholas Dec 1, 2023
72eb62d
avoid some warnings in the docs
TomNicholas Dec 1, 2023
e911155
silence warning caused by #8491
TomNicholas Dec 1, 2023
382601f
Merge branch 'main' into deprecate_ds_dims_returning_dict
TomNicholas Dec 1, 2023
2c9cb7d
fix another warning
TomNicholas Dec 1, 2023
d862f89
typing of .get
TomNicholas Dec 1, 2023
225414c
fix various uses of ds.dims in tests
TomNicholas Dec 1, 2023
3391c83
fix some warnings
TomNicholas Dec 1, 2023
93f8caa
add test that FutureWarnings are correctly raised
TomNicholas Dec 1, 2023
c0aa491
more fixes to avoid warnings
TomNicholas Dec 1, 2023
e4d45ac
update tests to avoid warnings
TomNicholas Dec 1, 2023
abc4a20
yet more fixes to avoid warnings
TomNicholas Dec 2, 2023
ef11e77
also warn in groupby.dims
TomNicholas Dec 2, 2023
965673b
change groupby tests to match
TomNicholas Dec 2, 2023
27f7642
update whatsnew to include groupby deprecation
TomNicholas Dec 2, 2023
3406cc7
filter warning when we actually test ds.dims
TomNicholas Dec 2, 2023
726f390
remove error I used for debugging
TomNicholas Dec 2, 2023
d1209ea
Merge branch 'main' into deprecate_ds_dims_returning_dict
TomNicholas Dec 2, 2023
0b6466e
Merge branch 'main' into deprecate_ds_dims_returning_dict
TomNicholas Dec 3, 2023
b473d24
Merge branch 'main' into deprecate_ds_dims_returning_dict
TomNicholas Dec 5, 2023
ddbab72
Merge branch 'main' into deprecate_ds_dims_returning_dict
dcherian Dec 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/gallery/plot_cartopy_facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
transform=ccrs.PlateCarree(), # the data's projection
col="time",
col_wrap=1, # multiplot settings
aspect=ds.dims["lon"] / ds.dims["lat"], # for a sensible figsize
aspect=ds.sizes["lon"] / ds.sizes["lat"], # for a sensible figsize
subplot_kws={"projection": map_proj}, # the plot's projection
)

Expand Down
4 changes: 2 additions & 2 deletions doc/user-guide/interpolation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ Let's see how :py:meth:`~xarray.DataArray.interp` works on real data.
axes[0].set_title("Raw data")

# Interpolated data
new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims["lon"] * 4)
new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims["lat"] * 4)
new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.sizes["lon"] * 4)
new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.sizes["lat"] * 4)
dsi = ds.interp(lat=new_lat, lon=new_lon)
dsi.air.plot(ax=axes[1])
@savefig interpolation_sample3.png width=8in
Expand Down
9 changes: 4 additions & 5 deletions doc/user-guide/terminology.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ complete examples, please consult the relevant documentation.*
all but one of these degrees of freedom is fixed. We can think of each
dimension axis as having a name, for example the "x dimension". In
xarray, a ``DataArray`` object's *dimensions* are its named dimension
axes, and the name of the ``i``-th dimension is ``arr.dims[i]``. If an
array is created without dimension names, the default dimension names are
``dim_0``, ``dim_1``, and so forth.
axes ``da.dims``, and the name of the ``i``-th dimension is ``da.dims[i]``.
If an array is created without specifying dimension names, the default dimension
names will be ``dim_0``, ``dim_1``, and so forth.

Coordinate
An array that labels a dimension or set of dimensions of another
Expand All @@ -61,8 +61,7 @@ complete examples, please consult the relevant documentation.*
``arr.coords[x]``. A ``DataArray`` can have more coordinates than
dimensions because a single dimension can be labeled by multiple
coordinate arrays. However, only one coordinate array can be a assigned
as a particular dimension's dimension coordinate array. As a
consequence, ``len(arr.dims) <= len(arr.coords)`` in general.
as a particular dimension's dimension coordinate array.

Dimension coordinate
A one-dimensional coordinate array assigned to ``arr`` with both a name
Expand Down
6 changes: 6 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ Deprecations
currently ``PendingDeprecationWarning``, which are silenced by default. We'll
convert these to ``DeprecationWarning`` in a future release.
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Raise a ``FutureWarning`` warning that the type of :py:meth:`Dataset.dims` will be changed
from a mapping of dimension names to lengths to a set of dimension names.
This is to increase consistency with :py:meth:`DataArray.dims`.
To access a mapping of dimension names to lengths please use :py:meth:`Dataset.sizes`.
(:issue:`8496`, :pull:`8500`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Bug fixes
~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def _dataset_indexer(dim: Hashable) -> DataArray:
cond_wdim = cond.drop_vars(
var for var in cond if dim not in cond[var].dims
)
keepany = cond_wdim.any(dim=(d for d in cond.dims.keys() if d != dim))
keepany = cond_wdim.any(dim=(d for d in cond.dims if d != dim))
return keepany.to_dataarray().any("variable")

_get_indexer = (
Expand Down
36 changes: 19 additions & 17 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from xarray.core.utils import (
Default,
Frozen,
FrozenMappingWarningOnValuesAccess,
HybridMappingProxy,
OrderedSet,
_default,
Expand Down Expand Up @@ -777,14 +778,15 @@ def dims(self) -> Frozen[Hashable, int]:

Note that type of this object differs from `DataArray.dims`.
See `Dataset.sizes` and `DataArray.sizes` for consistently named
properties.
properties. This property will be changed to return a type more consistent with
`DataArray.dims` in the future, i.e. a set of dimension names.

See Also
--------
Dataset.sizes
DataArray.dims
"""
return Frozen(self._dims)
return FrozenMappingWarningOnValuesAccess(self._dims)

@property
def sizes(self) -> Frozen[Hashable, int]:
Expand All @@ -799,7 +801,7 @@ def sizes(self) -> Frozen[Hashable, int]:
--------
DataArray.sizes
"""
return self.dims
return Frozen(self._dims)

@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
Expand Down Expand Up @@ -1410,7 +1412,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
variables[name] = self._variables[name]
except KeyError:
ref_name, var_name, var = _get_virtual_variable(
self._variables, name, self.dims
self._variables, name, self.sizes
)
variables[var_name] = var
if ref_name in self._coord_names or ref_name in self.dims:
Expand All @@ -1425,7 +1427,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
for v in variables.values():
needed_dims.update(v.dims)

dims = {k: self.dims[k] for k in needed_dims}
dims = {k: self.sizes[k] for k in needed_dims}

# preserves ordering of coordinates
for k in self._variables:
Expand All @@ -1447,7 +1449,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
try:
variable = self._variables[name]
except KeyError:
_, name, variable = _get_virtual_variable(self._variables, name, self.dims)
_, name, variable = _get_virtual_variable(self._variables, name, self.sizes)

needed_dims = set(variable.dims)

Expand All @@ -1474,7 +1476,7 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords)

# virtual coordinates
yield HybridMappingProxy(keys=self.dims, mapping=self)
yield HybridMappingProxy(keys=self.sizes, mapping=self)

def __contains__(self, key: object) -> bool:
"""The 'in' operator will return true or false depending on whether
Expand Down Expand Up @@ -2551,7 +2553,7 @@ def info(self, buf: IO | None = None) -> None:
lines = []
lines.append("xarray.Dataset {")
lines.append("dimensions:")
for name, size in self.dims.items():
for name, size in self.sizes.items():
lines.append(f"\t{name} = {size} ;")
lines.append("\nvariables:")
for name, da in self.variables.items():
Expand Down Expand Up @@ -2679,10 +2681,10 @@ def chunk(
else:
chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

bad_dims = chunks_mapping.keys() - self.dims.keys()
bad_dims = chunks_mapping.keys() - self.sizes.keys()
if bad_dims:
raise ValueError(
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.dims)}"
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
)

chunkmanager = guess_chunkmanager(chunked_array_type)
Expand Down Expand Up @@ -4158,7 +4160,7 @@ def _rename_vars(
return variables, coord_names

def _rename_dims(self, name_dict: Mapping[Any, Hashable]) -> dict[Hashable, int]:
return {name_dict.get(k, k): v for k, v in self.dims.items()}
return {name_dict.get(k, k): v for k, v in self.sizes.items()}

def _rename_indexes(
self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable]
Expand Down Expand Up @@ -5150,7 +5152,7 @@ def _get_stack_index(
if dim in self._variables:
var = self._variables[dim]
else:
_, _, var = _get_virtual_variable(self._variables, dim, self.dims)
_, _, var = _get_virtual_variable(self._variables, dim, self.sizes)
# dummy index (only `stack_coords` will be used to construct the multi-index)
stack_index = PandasIndex([0], dim)
stack_coords = {dim: var}
Expand All @@ -5177,7 +5179,7 @@ def _stack_once(
if any(d in var.dims for d in dims):
add_dims = [d for d in dims if d not in var.dims]
vdims = list(var.dims) + add_dims
shape = [self.dims[d] for d in vdims]
shape = [self.sizes[d] for d in vdims]
exp_var = var.set_dims(vdims, shape)
stacked_var = exp_var.stack(**{new_dim: dims})
new_variables[name] = stacked_var
Expand Down Expand Up @@ -6319,15 +6321,15 @@ def dropna(
if subset is None:
subset = iter(self.data_vars)

count = np.zeros(self.dims[dim], dtype=np.int64)
count = np.zeros(self.sizes[dim], dtype=np.int64)
size = np.int_(0) # for type checking

for k in subset:
array = self._variables[k]
if dim in array.dims:
dims = [d for d in array.dims if d != dim]
count += np.asarray(array.count(dims))
size += math.prod([self.dims[d] for d in dims])
size += math.prod([self.sizes[d] for d in dims])

if thresh is not None:
mask = count >= thresh
Expand Down Expand Up @@ -7104,7 +7106,7 @@ def _normalize_dim_order(
f"Dataset: {list(self.dims)}"
)

ordered_dims = {k: self.dims[k] for k in dim_order}
ordered_dims = {k: self.sizes[k] for k in dim_order}

return ordered_dims

Expand Down Expand Up @@ -7364,7 +7366,7 @@ def to_dask_dataframe(
var = self.variables[name]
except KeyError:
# dimension without a matching coordinate
size = self.dims[name]
size = self.sizes[name]
data = da.arange(size, chunks=size, dtype=np.int64)
var = Variable((name,), data)

Expand Down
40 changes: 40 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,15 @@
Collection,
Container,
Hashable,
ItemsView,
Iterable,
Iterator,
KeysView,
Mapping,
MutableMapping,
MutableSet,
Sequence,
ValuesView,
)
from enum import Enum
from typing import (
Expand Down Expand Up @@ -473,6 +476,43 @@ def FrozenDict(*args, **kwargs) -> Frozen:
return Frozen(dict(*args, **kwargs))


class FrozenMappingWarningOnValuesAccess(Frozen[K, V]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very clever!

"""
Class which behaves like a Mapping but warns if the values are accessed.

Temporary object to aid in deprecation cycle of `Dataset.dims` (see GH issue #8496).
`Dataset.dims` is being changed from returning a mapping of dimension names to lengths to just
returning a frozen set of dimension names (to increase consistency with `DataArray.dims`).
This class retains backwards compatibility but raises a warning only if the return value
of ds.dims is used like a dictionary (i.e. it doesn't raise a warning if used in a way that
would also be valid for a FrozenSet, e.g. iteration).
"""

def _warn(self) -> None:
warnings.warn(
"The return type of `Dataset.dims` will be changed to return a set of dimension names in future, "
"in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, "
"please use `Dataset.sizes`.",
FutureWarning,
)

def __getitem__(self, key: K) -> V:
self._warn()
return super().__getitem__(key)

def keys(self) -> KeysView[K]:
self._warn()
return super().keys()

def items(self) -> ItemsView[K, V]:
self._warn()
return super().items()

def values(self) -> ValuesView[V]:
self._warn()
return super().values()


class HybridMappingProxy(Mapping[K, V]):
"""Implements the Mapping interface. Uses the wrapped mapping for item lookup
and a separate wrapped keys collection for iteration.
Expand Down
9 changes: 8 additions & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,18 @@ def source_ndarray(array):
return base


def format_record(record) -> str:
"""Format warning record like `FutureWarning('Function will be deprecated...')`"""
return f"{str(record.category)[8:-2]}('{record.message}'))"


@contextmanager
def assert_no_warnings():
with warnings.catch_warnings(record=True) as record:
yield record
assert len(record) == 0, "got unexpected warning(s)"
assert (
len(record) == 0
), f"Got {len(record)} unexpected warning(s): {[format_record(r) for r in record]}"


# Internal versions of xarray's test functions that validate additional
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,6 +1705,7 @@ def test_broadcasting_math(self):
v * w[0], Variable(["a", "b", "c", "d"], np.einsum("ab,cd->abcd", x, y[0]))
)

@pytest.mark.filterwarnings("ignore:Duplicate dimension names")
def test_broadcasting_failures(self):
a = Variable(["x"], np.arange(10))
b = Variable(["x"], np.arange(5))
Expand Down
Loading