""",
twitter_url="https://twitter.com/xarray_dev",
icon_links=[], # workaround for pydata/pydata-sphinx-theme#1220
+ announcement="🍾 Xarray is now 10 years old! 🎉",
)
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
-html_logo = "_static/dataset-diagram-logo.png"
+html_logo = "_static/logos/Xarray_Logo_RGB_Final.svg"
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
-html_favicon = "_static/favicon.ico"
+html_favicon = "_static/logos/Xarray_Icon_Final.svg"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
@@ -263,11 +264,11 @@
# configuration for sphinxext.opengraph
ogp_site_url = "https://docs.xarray.dev/en/latest/"
-ogp_image = "https://docs.xarray.dev/en/stable/_static/dataset-diagram-logo.png"
+ogp_image = "https://docs.xarray.dev/en/stable/_static/logos/Xarray_Logo_RGB_Final.png"
ogp_custom_meta_tags = [
'',
'',
- '',
+ '',
]
# Redirects for pages that were moved to new locations
@@ -325,6 +326,7 @@
"dask": ("https://docs.dask.org/en/latest", None),
"cftime": ("https://unidata.github.io/cftime", None),
"sparse": ("https://sparse.pydata.org/en/latest/", None),
+ "hypothesis": ("https://hypothesis.readthedocs.io/en/latest/", None),
"cubed": ("https://tom-e-white.com/cubed/", None),
"datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None),
"xarray-tutorial": ("https://tutorial.xarray.dev/", None),
diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst
index fc5ae963a1d..561e9cdb5b2 100644
--- a/doc/ecosystem.rst
+++ b/doc/ecosystem.rst
@@ -78,6 +78,7 @@ Extend xarray capabilities
- `xarray-dataclasses `_: xarray extension for typed DataArray and Dataset creation.
- `xarray_einstats `_: Statistics, linear algebra and einops for xarray
- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations).
+- `xeofs `_: PCA/EOF analysis and related techniques, integrated with xarray and Dask for efficient handling of large-scale data.
- `xpublish `_: Publish Xarray Datasets via a Zarr compatible REST API.
- `xrft `_: Fourier transforms for xarray data.
- `xr-scipy `_: A lightweight scipy wrapper for xarray.
diff --git a/doc/gallery.yml b/doc/gallery.yml
index f1a147dae87..f8316017d8c 100644
--- a/doc/gallery.yml
+++ b/doc/gallery.yml
@@ -25,12 +25,12 @@ notebooks-examples:
- title: Applying unvectorized functions with apply_ufunc
path: examples/apply_ufunc_vectorize_1d.html
- thumbnail: _static/dataset-diagram-square-logo.png
+ thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg
external-examples:
- title: Managing raster data with rioxarray
path: https://corteva.github.io/rioxarray/stable/examples/examples.html
- thumbnail: _static/dataset-diagram-square-logo.png
+ thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg
- title: Xarray and dask on the cloud with Pangeo
path: https://gallery.pangeo.io/
@@ -38,7 +38,7 @@ external-examples:
- title: Xarray with Dask Arrays
path: https://examples.dask.org/xarray.html_
- thumbnail: _static/dataset-diagram-square-logo.png
+ thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg
- title: Project Pythia Foundations Book
path: https://foundations.projectpythia.org/core/xarray.html
diff --git a/doc/gallery/plot_cartopy_facetgrid.py b/doc/gallery/plot_cartopy_facetgrid.py
index d8f5e73ee56..a4ab23c42a6 100644
--- a/doc/gallery/plot_cartopy_facetgrid.py
+++ b/doc/gallery/plot_cartopy_facetgrid.py
@@ -30,7 +30,7 @@
transform=ccrs.PlateCarree(), # the data's projection
col="time",
col_wrap=1, # multiplot settings
- aspect=ds.dims["lon"] / ds.dims["lat"], # for a sensible figsize
+ aspect=ds.sizes["lon"] / ds.sizes["lat"], # for a sensible figsize
subplot_kws={"projection": map_proj}, # the plot's projection
)
diff --git a/doc/internals/duck-arrays-integration.rst b/doc/internals/duck-arrays-integration.rst
index a674acb04fe..43b17be8bb8 100644
--- a/doc/internals/duck-arrays-integration.rst
+++ b/doc/internals/duck-arrays-integration.rst
@@ -31,6 +31,8 @@ property needs to obey `numpy's broadcasting rules `_
of these same rules).
+.. _internals.duckarrays.array_api_standard:
+
Python Array API standard support
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/doc/user-guide/index.rst b/doc/user-guide/index.rst
index 0ac25d68930..45f0ce352de 100644
--- a/doc/user-guide/index.rst
+++ b/doc/user-guide/index.rst
@@ -25,4 +25,5 @@ examples that describe many common tasks that you can accomplish with xarray.
dask
plotting
options
+ testing
duckarrays
diff --git a/doc/user-guide/interpolation.rst b/doc/user-guide/interpolation.rst
index 7b40962e826..311e1bf0129 100644
--- a/doc/user-guide/interpolation.rst
+++ b/doc/user-guide/interpolation.rst
@@ -292,8 +292,8 @@ Let's see how :py:meth:`~xarray.DataArray.interp` works on real data.
axes[0].set_title("Raw data")
# Interpolated data
- new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims["lon"] * 4)
- new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims["lat"] * 4)
+ new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.sizes["lon"] * 4)
+ new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.sizes["lat"] * 4)
dsi = ds.interp(lat=new_lat, lon=new_lon)
dsi.air.plot(ax=axes[1])
@savefig interpolation_sample3.png width=8in
diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst
index 2155ecfd88b..48751c5f299 100644
--- a/doc/user-guide/io.rst
+++ b/doc/user-guide/io.rst
@@ -804,7 +804,7 @@ store. These options are useful for scenarios when it is infeasible or
undesirable to write your entire dataset at once.
1. Use ``mode='a'`` to add or overwrite entire variables,
-2. Use ``append_dim`` to resize and append to exiting variables, and
+2. Use ``append_dim`` to resize and append to existing variables, and
3. Use ``region`` to write to limited regions of existing arrays.
.. tip::
diff --git a/doc/user-guide/terminology.rst b/doc/user-guide/terminology.rst
index ce7e55546a4..55937310827 100644
--- a/doc/user-guide/terminology.rst
+++ b/doc/user-guide/terminology.rst
@@ -47,9 +47,9 @@ complete examples, please consult the relevant documentation.*
all but one of these degrees of freedom is fixed. We can think of each
dimension axis as having a name, for example the "x dimension". In
xarray, a ``DataArray`` object's *dimensions* are its named dimension
- axes, and the name of the ``i``-th dimension is ``arr.dims[i]``. If an
- array is created without dimension names, the default dimension names are
- ``dim_0``, ``dim_1``, and so forth.
+ axes ``da.dims``, and the name of the ``i``-th dimension is ``da.dims[i]``.
+ If an array is created without specifying dimension names, the default dimension
+ names will be ``dim_0``, ``dim_1``, and so forth.
Coordinate
An array that labels a dimension or set of dimensions of another
@@ -61,8 +61,7 @@ complete examples, please consult the relevant documentation.*
``arr.coords[x]``. A ``DataArray`` can have more coordinates than
dimensions because a single dimension can be labeled by multiple
     coordinate arrays. However, only one coordinate array can be assigned
- as a particular dimension's dimension coordinate array. As a
- consequence, ``len(arr.dims) <= len(arr.coords)`` in general.
+ as a particular dimension's dimension coordinate array.
Dimension coordinate
A one-dimensional coordinate array assigned to ``arr`` with both a name
diff --git a/doc/user-guide/testing.rst b/doc/user-guide/testing.rst
new file mode 100644
index 00000000000..13279eccb0b
--- /dev/null
+++ b/doc/user-guide/testing.rst
@@ -0,0 +1,303 @@
+.. _testing:
+
+Testing your code
+=================
+
+.. ipython:: python
+ :suppress:
+
+ import numpy as np
+ import pandas as pd
+ import xarray as xr
+
+ np.random.seed(123456)
+
+.. _testing.hypothesis:
+
+Hypothesis testing
+------------------
+
+.. note::
+
+    Testing with hypothesis is a fairly advanced topic. Before reading this section it is recommended that you
+    take a look at our guide to xarray's :ref:`data structures`, familiarize yourself with conventional unit testing in
+    `pytest `_, and skim the
+    `hypothesis library documentation `_.
+
+`The hypothesis library `_ is a powerful tool for property-based testing.
+Instead of writing tests for one example at a time, it allows you to write tests parameterized by a source of many
+dynamically generated examples. For example, you might have written a test which you wish to parameterize over the set
+of all possible integers via :py:func:`hypothesis.strategies.integers()`.
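+
+As a simple illustration, a conventional hypothesis test (independent of xarray) might look like the sketch below,
+where ``double`` is a hypothetical stand-in for the code under test:
+
+.. code-block:: python
+
+    from hypothesis import given
+    import hypothesis.strategies as st
+
+
+    def double(x: int) -> int:
+        return 2 * x
+
+
+    @given(st.integers())
+    def test_double_is_even(x):
+        # the property must hold for every integer that hypothesis generates
+        assert double(x) % 2 == 0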
+
+Property-based testing is extremely powerful, because (unlike more conventional example-based testing) it can find bugs
+that you did not even think to look for!
+
+Strategies
+~~~~~~~~~~
+
+Each source of examples is called a "strategy", and xarray provides a range of custom strategies which produce xarray
+data structures containing arbitrary data. You can use these to efficiently test downstream code,
+quickly ensuring that your code can handle xarray objects of all possible structures and contents.
+
+These strategies are accessible in the :py:mod:`xarray.testing.strategies` module, which provides
+
+.. currentmodule:: xarray
+
+.. autosummary::
+
+ testing.strategies.supported_dtypes
+ testing.strategies.names
+ testing.strategies.dimension_names
+ testing.strategies.dimension_sizes
+ testing.strategies.attrs
+ testing.strategies.variables
+ testing.strategies.unique_subset_of
+
+These build upon the numpy and array API strategies offered in :py:mod:`hypothesis.extra.numpy` and :py:mod:`hypothesis.extra.array_api`:
+
+.. ipython:: python
+
+ import hypothesis.extra.numpy as npst
+
+Generating Examples
+~~~~~~~~~~~~~~~~~~~
+
+To see an example of what each of these strategies might produce, you can call one followed by the ``.example()`` method,
+which is a general hypothesis method valid for all strategies.
+
+.. ipython:: python
+
+ import xarray.testing.strategies as xrst
+
+ xrst.variables().example()
+ xrst.variables().example()
+ xrst.variables().example()
+
+You can see that calling ``.example()`` multiple times will generate different examples, giving you an idea of the wide
+range of data that the xarray strategies can generate.
+
+In your tests, however, you should not use ``.example()``; instead, parameterize your tests with the
+:py:func:`hypothesis.given` decorator:
+
+.. ipython:: python
+
+ from hypothesis import given
+
+.. ipython:: python
+
+ @given(xrst.variables())
+ def test_function_that_acts_on_variables(var):
+ assert func(var) == ...
+
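+For a more concrete (though still hypothetical) property, the sketch below checks that transposing a
+:py:class:`~xarray.Variable` never changes its total number of elements:
+
+.. code-block:: python
+
+    from hypothesis import given
+
+    import xarray.testing.strategies as xrst
+
+
+    @given(xrst.variables())
+    def test_transpose_preserves_size(var):
+        # transposing reorders dimensions but never adds or removes elements
+        assert var.transpose().size == var.size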
+
+Chaining Strategies
+~~~~~~~~~~~~~~~~~~~
+
+Xarray's strategies can accept other strategies as arguments, allowing you to customise the contents of the generated
+examples.
+
+.. ipython:: python
+
+ # generate a Variable containing an array with a complex number dtype, but all other details still arbitrary
+ from hypothesis.extra.numpy import complex_number_dtypes
+
+ xrst.variables(dtype=complex_number_dtypes()).example()
+
+This also works with custom strategies, or strategies defined in other packages.
+For example you could imagine creating a ``chunks`` strategy to specify particular chunking patterns for a dask-backed array.
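+
+As a very rough sketch of that idea (assuming dask is installed; the ``chunked_variables`` name below is made up
+purely for illustration), you could draw a chunk size and apply it to a generated one-dimensional ``Variable``:
+
+.. code-block:: python
+
+    import hypothesis.strategies as st
+
+    import xarray.testing.strategies as xrst
+
+    # draw an integer chunk size, then chunk a generated 1D Variable with it (requires dask)
+    chunked_variables = st.integers(min_value=1, max_value=4).flatmap(
+        lambda chunk_size: xrst.variables(
+            dims=xrst.dimension_names(min_dims=1, max_dims=1)
+        ).map(lambda var: var.chunk(chunk_size))
+    )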
+
+Fixing Arguments
+~~~~~~~~~~~~~~~~
+
+If you want to fix one aspect of the data structure, whilst allowing variation in the generated examples
+over all other aspects, then use :py:func:`hypothesis.strategies.just()`.
+
+.. ipython:: python
+
+ import hypothesis.strategies as st
+
+ # Generates only variable objects with dimensions ["x", "y"]
+ xrst.variables(dims=st.just(["x", "y"])).example()
+
+(This is technically another example of chaining strategies: :py:func:`hypothesis.strategies.just()` is simply a
+special strategy that just contains a single example.)
+
+To fix the length of dimensions you can instead pass ``dims`` as a mapping of dimension names to lengths
+(i.e. following xarray objects' ``.sizes`` property), e.g.
+
+.. ipython:: python
+
+ # Generates only variables with dimensions ["x", "y"], of lengths 2 & 3 respectively
+ xrst.variables(dims=st.just({"x": 2, "y": 3})).example()
+
+You can also use this to specify that you want examples which are missing some part of the data structure, for instance
+
+.. ipython:: python
+
+ # Generates a Variable with no attributes
+ xrst.variables(attrs=st.just({})).example()
+
+Through a combination of chaining strategies and fixing arguments, you can specify quite complicated requirements on the
+objects your chained strategy will generate.
+
+.. ipython:: python
+
+ fixed_x_variable_y_maybe_z = st.fixed_dictionaries(
+ {"x": st.just(2), "y": st.integers(3, 4)}, optional={"z": st.just(2)}
+ )
+ fixed_x_variable_y_maybe_z.example()
+
+ special_variables = xrst.variables(dims=fixed_x_variable_y_maybe_z)
+
+ special_variables.example()
+ special_variables.example()
+
+Here we have used one of hypothesis' built-in strategies :py:func:`hypothesis.strategies.fixed_dictionaries` to create a
+strategy which generates mappings of dimension names to lengths (i.e. the ``sizes`` of the xarray object we want).
+This particular strategy will always generate an ``x`` dimension of length 2, and a ``y`` dimension of
+length either 3 or 4, and will sometimes also generate a ``z`` dimension of length 2.
+By feeding this strategy for dictionaries into the ``dims`` argument of xarray's :py:func:`~xarray.testing.strategies.variables` strategy,
+we can generate arbitrary :py:class:`~xarray.Variable` objects whose dimensions will always match these specifications.
+
+Generating Duck-type Arrays
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Xarray objects don't have to wrap numpy arrays; in fact they can wrap any array type which presents the same API as a
+numpy array (so-called "duck array wrapping", see :ref:`wrapping numpy-like arrays `).
+
+Imagine we want to write a strategy which generates arbitrary ``Variable`` objects, each of which wraps a
+:py:class:`sparse.COO` array instead of a ``numpy.ndarray``. How could we do that? There are two ways:
+
+1. Create an xarray object with numpy data and use hypothesis' ``.map()`` method to convert the underlying array to a
+different type:
+
+.. ipython:: python
+
+ import sparse
+
+.. ipython:: python
+
+ def convert_to_sparse(var):
+ return var.copy(data=sparse.COO.from_numpy(var.to_numpy()))
+
+.. ipython:: python
+
+ sparse_variables = xrst.variables(dims=xrst.dimension_names(min_dims=1)).map(
+ convert_to_sparse
+ )
+
+ sparse_variables.example()
+ sparse_variables.example()
+
+2. Pass a function which returns a strategy which generates the duck-typed arrays directly to the ``array_strategy_fn`` argument of the xarray strategies:
+
+.. ipython:: python
+
+    def sparse_random_arrays(shape: tuple[int, ...]) -> st.SearchStrategy[sparse._coo.core.COO]:
+ """Strategy which generates random sparse.COO arrays"""
+ if shape is None:
+ shape = npst.array_shapes()
+ else:
+ shape = st.just(shape)
+ density = st.integers(min_value=0, max_value=1)
+ # note sparse.random does not accept a dtype kwarg
+ return st.builds(sparse.random, shape=shape, density=density)
+
+
+ def sparse_random_arrays_fn(
+ *, shape: tuple[int, ...], dtype: np.dtype
+ ) -> st.SearchStrategy[sparse._coo.core.COO]:
+ return sparse_random_arrays(shape=shape)
+
+
+.. ipython:: python
+
+ sparse_random_variables = xrst.variables(
+ array_strategy_fn=sparse_random_arrays_fn, dtype=st.just(np.dtype("float64"))
+ )
+ sparse_random_variables.example()
+
+Either approach is fine, but one may be more convenient than the other depending on the type of the duck array which you
+want to wrap.
+
+Compatibility with the Python Array API Standard
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Xarray aims to be compatible with any duck-array type that conforms to the `Python Array API Standard `_
+(see our :ref:`docs on Array API Standard support `).
+
+.. warning::
+
+ The strategies defined in :py:mod:`testing.strategies` are **not** guaranteed to use array API standard-compliant
+ dtypes by default.
+ For example arrays with the dtype ``np.dtype('float16')`` may be generated by :py:func:`testing.strategies.variables`
+ (assuming the ``dtype`` kwarg was not explicitly passed), despite ``np.dtype('float16')`` not being in the
+ array API standard.
+
+If the array type you want to generate has an array API-compliant top-level namespace
+(e.g. that which is conventionally imported as ``xp`` or similar),
+you can use this neat trick:
+
+.. ipython:: python
+ :okwarning:
+
+ from numpy import array_api as xp # available in numpy 1.26.0
+
+ from hypothesis.extra.array_api import make_strategies_namespace
+
+ xps = make_strategies_namespace(xp)
+
+ xp_variables = xrst.variables(
+ array_strategy_fn=xps.arrays,
+ dtype=xps.scalar_dtypes(),
+ )
+ xp_variables.example()
+
+For another array API-compliant duck array library, you would simply replace the import, e.g. with ``import cupy as cp`` instead.
+
+Testing over Subsets of Dimensions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+A common task when testing xarray user code is checking that your function works for all valid input dimensions.
+We can chain strategies to achieve this, for which the helper strategy :py:func:`~testing.strategies.unique_subset_of`
+is useful.
+
+It works for lists of dimension names
+
+.. ipython:: python
+
+ dims = ["x", "y", "z"]
+ xrst.unique_subset_of(dims).example()
+ xrst.unique_subset_of(dims).example()
+
+as well as for mappings of dimension names to sizes
+
+.. ipython:: python
+
+ dim_sizes = {"x": 2, "y": 3, "z": 4}
+ xrst.unique_subset_of(dim_sizes).example()
+ xrst.unique_subset_of(dim_sizes).example()
+
+This is useful because operations like reductions can be performed over any subset of the xarray object's dimensions.
+For example, we can write a pytest test which checks that a reduction gives the expected result when applied
+along any possible valid subset of the Variable's dimensions.
+
+.. code-block:: python
+
+ import numpy.testing as npt
+
+
+ @given(st.data(), xrst.variables(dims=xrst.dimension_names(min_dims=1)))
+ def test_mean(data, var):
+ """Test that the mean of an xarray Variable is always equal to the mean of the underlying array."""
+
+ # specify arbitrary reduction along at least one dimension
+ reduction_dims = data.draw(xrst.unique_subset_of(var.dims, min_size=1))
+
+ # create expected result (using nanmean because arrays with Nans will be generated)
+ reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims)
+ expected = np.nanmean(var.data, axis=reduction_axes)
+
+ # assert property is always satisfied
+ result = var.mean(dim=reduction_dims).data
+ npt.assert_equal(expected, result)
diff --git a/doc/user-guide/time-series.rst b/doc/user-guide/time-series.rst
index cbb831cac3a..82172aa8998 100644
--- a/doc/user-guide/time-series.rst
+++ b/doc/user-guide/time-series.rst
@@ -245,6 +245,18 @@ Data that has indices outside of the given ``tolerance`` are set to ``NaN``.
ds.resample(time="1h").nearest(tolerance="1h")
+It is often desirable to center the time values after a resampling operation.
+That can be accomplished by updating the resampled dataset time coordinate values
+using time offset arithmetic via the `pandas.tseries.frequencies.to_offset`_ function.
+
+.. _pandas.tseries.frequencies.to_offset: https://pandas.pydata.org/docs/reference/api/pandas.tseries.frequencies.to_offset.html
+
+.. ipython:: python
+
+ resampled_ds = ds.resample(time="6h").mean()
+ offset = pd.tseries.frequencies.to_offset("6h") / 2
+ resampled_ds["time"] = resampled_ds.get_index("time") + offset
+ resampled_ds
For more examples of using grouped operations on a time dimension, see
:doc:`../examples/weather-data`.
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 2fb76cfe8c2..c0917b7443b 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -14,14 +14,179 @@ What's New
np.random.seed(123456)
-.. _whats-new.2023.10.2:
-v2023.10.2 (unreleased)
+
+.. _whats-new.2023.12.1:
+
+v2023.12.1 (unreleased)
-----------------------
New Features
~~~~~~~~~~~~
+- :py:func:`xr.cov` and :py:func:`xr.corr` now support using weights (:issue:`8527`, :pull:`7392`); see the sketch below.
+ By `Llorenç Lledó `_.
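+
+A minimal sketch of the new ``weights`` argument (the data here is made up purely for illustration):
+
+.. code-block:: python
+
+    import numpy as np
+    import xarray as xr
+
+    da_a = xr.DataArray(np.random.randn(3), dims="space")
+    da_b = xr.DataArray(np.random.randn(3), dims="space")
+    weights = xr.DataArray([4.0, 2.0, 1.0], dims="space")
+
+    xr.cov(da_a, da_b, dim="space", weights=weights)
+    xr.corr(da_a, da_b, dim="space", weights=weights)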
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+
+Deprecations
+~~~~~~~~~~~~
+
+
+Bug fixes
+~~~~~~~~~
+
+
+Documentation
+~~~~~~~~~~~~~
+
+
+Internal Changes
+~~~~~~~~~~~~~~~~
+
+- Remove null values before plotting. (:pull:`8535`).
+ By `Jimmy Westling `_.
+
+.. _whats-new.2023.12.0:
+
+v2023.12.0 (2023 Dec 08)
+------------------------
+
+This release brings new `hypothesis `_ strategies for testing, significantly faster rolling aggregations as well as
+``ffill`` and ``bfill`` with ``numbagg``, a new :py:meth:`Dataset.eval` method, and improvements to
+reading and writing Zarr arrays (including a new ``"a-"`` mode).
+
+Thanks to our 16 contributors:
+
+Anderson Banihirwe, Ben Mares, Carl Andersson, Deepak Cherian, Doug Latornell, Gregorio L. Trevisan, Illviljan, Jens Hedegaard Nielsen, Justus Magin, Mathias Hauser, Max Jones, Maximilian Roos, Michael Niklas, Patrick Hoefler, Ryan Abernathey, Tom Nicholas
+
+New Features
+~~~~~~~~~~~~
+
+- Added hypothesis strategies for generating :py:class:`xarray.Variable` objects containing arbitrary data, useful for parametrizing downstream tests.
+ Accessible under :py:mod:`testing.strategies`, and documented in a new page on testing in the User Guide.
+ (:issue:`6911`, :pull:`8404`)
+ By `Tom Nicholas `_.
+- :py:meth:`rolling` uses `numbagg `_ for
+ most of its computations by default. Numbagg is up to 5x faster than bottleneck
+ where parallelization is possible. Where parallelization isn't possible — for
+ example a 1D array — it's about the same speed as bottleneck, and 2-5x faster
+ than pandas' default functions. (:pull:`8493`). numbagg is an optional
+ dependency, so requires installing separately.
+- Add :py:meth:`DataArray.cumulative` & :py:meth:`Dataset.cumulative` to compute
+ cumulative aggregations, such as ``sum``, along a dimension — for example
+ ``da.cumulative('time').sum()``. This is similar to pandas' ``.expanding``,
+ and mostly equivalent to ``.cumsum`` methods, or to
+ :py:meth:`DataArray.rolling` with a window length equal to the dimension size.
+ (:pull:`8512`).
+ By `Maximilian Roos `_.
+- Use a concise format when plotting datetime arrays. (:pull:`8449`).
+ By `Jimmy Westling `_.
+- Avoid overwriting unchanged existing coordinate variables when appending with :py:meth:`Dataset.to_zarr` by setting ``mode='a-'``.
+ By `Ryan Abernathey `_ and `Deepak Cherian `_.
+- :py:meth:`~xarray.DataArray.rank` now operates on dask-backed arrays, assuming
+ the core dim has exactly one chunk. (:pull:`8475`).
+ By `Maximilian Roos `_.
+- Add a :py:meth:`Dataset.eval` method, similar to pandas' method of the
+  same name (:pull:`7163`). This is currently marked as experimental and
+  doesn't yet support the ``numexpr`` engine; a short sketch follows this list.
+- :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` allow passing a
+ callable, similar to :py:meth:`Dataset.where` & :py:meth:`Dataset.sortby` & others.
+ (:pull:`8511`).
+ By `Maximilian Roos `_.
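+
+A short sketch of the new (experimental) :py:meth:`Dataset.eval` method, using a made-up dataset:
+
+.. code-block:: python
+
+    import numpy as np
+    import xarray as xr
+
+    ds = xr.Dataset({"a": ("x", np.arange(3)), "b": ("x", np.linspace(0, 1, 3))})
+    ds.eval("a + 2 * b")  # evaluates the expression against the dataset's variables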
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+- Explicitly warn when creating xarray objects with repeated dimension names.
+ Such objects will also now raise when :py:meth:`DataArray.get_axis_num` is called,
+ which means many functions will raise.
+ This latter change is technically a breaking change, but whilst allowed,
+ this behaviour was never actually supported! (:issue:`3731`, :pull:`8491`)
+ By `Tom Nicholas `_.
+
+Deprecations
+~~~~~~~~~~~~
+
+- As part of an effort to standardize the API, we're renaming the ``dims``
+  keyword arg to ``dim`` for the minority of functions which currently use
+ ``dims``. This started with :py:func:`xarray.dot` & :py:meth:`DataArray.dot`
+ and we'll gradually roll this out across all functions. The warnings are
+ currently ``PendingDeprecationWarning``, which are silenced by default. We'll
+ convert these to ``DeprecationWarning`` in a future release.
+ By `Maximilian Roos `_.
+- Raise a ``FutureWarning`` to announce that the type of :py:meth:`Dataset.dims` will be changed
+ from a mapping of dimension names to lengths to a set of dimension names.
+ This is to increase consistency with :py:meth:`DataArray.dims`.
+ To access a mapping of dimension names to lengths please use :py:meth:`Dataset.sizes`.
+ The same change also applies to `DatasetGroupBy.dims`.
+ (:issue:`8496`, :pull:`8500`)
+ By `Tom Nicholas `_.
+- :py:meth:`Dataset.drop` & :py:meth:`DataArray.drop` are now deprecated, after being slated for
+  deprecation for several years. :py:meth:`DataArray.drop_sel` & :py:meth:`DataArray.drop_vars`
+ replace them for labels & variables respectively. (:pull:`8497`)
+ By `Maximilian Roos `_.
+
+Bug fixes
+~~~~~~~~~
+
+- Fix dtype inference for ``pd.CategoricalIndex`` when categories are backed by a ``pd.ExtensionDtype`` (:pull:`8481`)
+- Fix writing a variable that requires transposing when not writing to a region (:pull:`8484`)
+ By `Maximilian Roos `_.
+- Static typing of ``p0`` and ``bounds`` arguments of :py:func:`xarray.DataArray.curvefit` and :py:func:`xarray.Dataset.curvefit`
+ was changed to ``Mapping`` (:pull:`8502`).
+ By `Michael Niklas `_.
+- Fix typing of :py:func:`xarray.DataArray.to_netcdf` and :py:func:`xarray.Dataset.to_netcdf`
+ when ``compute`` is evaluated to bool instead of a Literal (:pull:`8268`).
+ By `Jens Hedegaard Nielsen `_.
+
+Documentation
+~~~~~~~~~~~~~
+
+- Added illustration of updating the time coordinate values of a resampled dataset using
+ time offset arithmetic.
+ This is the recommended technique to replace the use of the deprecated ``loffset`` parameter
+ in ``resample`` (:pull:`8479`).
+ By `Doug Latornell `_.
+- Improved error message when attempting to get a variable which doesn't exist from a Dataset.
+ (:pull:`8474`)
+ By `Maximilian Roos `_.
+- Fix default value of ``combine_attrs`` in :py:func:`xarray.combine_by_coords` (:pull:`8471`)
+ By `Gregorio L. Trevisan `_.
+
+Internal Changes
+~~~~~~~~~~~~~~~~
+
+- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use `numbagg `_ by
+ default, which is up to 5x faster where parallelization is possible. (:pull:`8339`)
+ By `Maximilian Roos `_.
+- Update mypy version to 1.7 (:issue:`8448`, :pull:`8501`).
+ By `Michael Niklas `_.
+
+.. _whats-new.2023.11.0:
+
+v2023.11.0 (Nov 16, 2023)
+-------------------------
+
+
+.. tip::
+
+ `This is our 10th year anniversary release! `_ Thank you for your love and support.
+
+
+This release brings the ability to use ``opt_einsum`` for :py:func:`xarray.dot` by default,
+support for auto-detecting ``region`` when writing partial datasets to Zarr, and the use of h5py
+drivers with ``h5netcdf``.
+
+Thanks to the 19 contributors to this release:
+Aman Bagrecha, Anderson Banihirwe, Ben Mares, Deepak Cherian, Dimitri Papadopoulos Orfanos, Ezequiel Cimadevilla Alvarez,
+Illviljan, Justus Magin, Katelyn FitzGerald, Kai Muehlbauer, Martin Durant, Maximilian Roos, Metamess, Sam Levang, Spencer Clark, Tom Nicholas, mgunyho, templiert
+
+New Features
+~~~~~~~~~~~~
+
- Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed.
By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`).
- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`).
@@ -32,13 +197,14 @@ New Features
By `Sam Levang `_.
- Allow the usage of h5py drivers (eg: ros3) via h5netcdf (:pull:`8360`).
By `Ezequiel Cimadevilla `_.
+- Enable VLEN string fill_values, preserve VLEN string dtypes (:issue:`1647`, :issue:`7652`, :issue:`7868`, :pull:`7869`).
+ By `Kai Mühlbauer `_.
Breaking changes
~~~~~~~~~~~~~~~~
- drop support for `cdms2 `_. Please use
`xcdat `_ instead (:pull:`8441`).
By `Justus Magin `_.
-
- Following pandas, :py:meth:`infer_freq` will return ``"Y"``, ``"YS"``,
``"QE"``, ``"ME"``, ``"h"``, ``"min"``, ``"s"``, ``"ms"``, ``"us"``, or
``"ns"`` instead of ``"A"``, ``"AS"``, ``"Q"``, ``"M"``, ``"H"``, ``"T"``,
@@ -46,6 +212,8 @@ Breaking changes
deprecation of the latter frequency strings (:issue:`8394`, :pull:`8415`). By
`Spencer Clark `_.
- Bump minimum tested pint version to ``>=0.22``. By `Deepak Cherian `_.
+- Minimum supported versions for the following packages have changed: ``h5py >=3.7``, ``h5netcdf>=1.1``.
+ By `Kai Mühlbauer `_.
Deprecations
~~~~~~~~~~~~
@@ -91,17 +259,15 @@ Bug fixes
- Fix a bug where :py:meth:`DataArray.to_dataset` silently drops a variable
if a coordinate with the same name already exists (:pull:`8433`, :issue:`7823`).
By `András Gunyhó `_.
+- Fix for :py:meth:`DataArray.to_zarr` & :py:meth:`Dataset.to_zarr` to close
+ the created zarr store when passing a path with `.zip` extension (:pull:`8425`).
+  By `Carl Andersson `_.
Documentation
~~~~~~~~~~~~~
- Small updates to documentation on distributed writes: See :ref:`io.zarr.appending` to Zarr.
By `Deepak Cherian `_.
-
-Internal Changes
-~~~~~~~~~~~~~~~~
-
-
.. _whats-new.2023.10.1:
v2023.10.1 (19 Oct, 2023)
@@ -423,6 +589,10 @@ Internal Changes
- :py:func:`as_variable` now consistently includes the variable name in any exceptions
raised. (:pull:`7995`). By `Peter Hill `_
+- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`,
+ potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
+ use non-dask chunked array types.
+ (:pull:`8019`) By `Tom Nicholas `_.
- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to
`coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`).
`By Ian Carroll `_.
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 3e6d00a8059..1d538bf94ed 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -39,6 +39,7 @@
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.indexes import Index
from xarray.core.parallelcompat import guess_chunkmanager
+from xarray.core.types import ZarrWriteModes
from xarray.core.utils import is_remote_uri
if TYPE_CHECKING:
@@ -69,7 +70,6 @@
"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"
]
-
DATAARRAY_NAME = "__xarray_dataarray_name__"
DATAARRAY_VARIABLE = "__xarray_dataarray_variable__"
@@ -1160,6 +1160,62 @@ def to_netcdf(
...
+# if compute cannot be evaluated at type check time
+# we may get back either Delayed or None
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike,
+ mode: Literal["w", "a"] = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = False,
+ multifile: Literal[False] = False,
+ invalid_netcdf: bool = False,
+) -> Delayed | None:
+ ...
+
+
+# if multifile cannot be evaluated at type check time
+# we may get back either writer and datastore or Delayed or None
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike,
+ mode: Literal["w", "a"] = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = False,
+ multifile: bool = False,
+ invalid_netcdf: bool = False,
+) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None:
+ ...
+
+
+# Any
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike | None,
+ mode: Literal["w", "a"] = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = False,
+ multifile: bool = False,
+ invalid_netcdf: bool = False,
+) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
+ ...
+
+
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None = None,
@@ -1577,7 +1633,7 @@ def to_zarr(
dataset: Dataset,
store: MutableMapping | str | os.PathLike[str] | None = None,
chunk_store: MutableMapping | str | os.PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
@@ -1601,7 +1657,7 @@ def to_zarr(
dataset: Dataset,
store: MutableMapping | str | os.PathLike[str] | None = None,
chunk_store: MutableMapping | str | os.PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
@@ -1623,7 +1679,7 @@ def to_zarr(
dataset: Dataset,
store: MutableMapping | str | os.PathLike[str] | None = None,
chunk_store: MutableMapping | str | os.PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
@@ -1680,16 +1736,18 @@ def to_zarr(
else:
mode = "w-"
- if mode != "a" and append_dim is not None:
+ if mode not in ["a", "a-"] and append_dim is not None:
raise ValueError("cannot set append_dim unless mode='a' or mode=None")
- if mode not in ["a", "r+"] and region is not None:
- raise ValueError("cannot set region unless mode='a', mode='r+' or mode=None")
+ if mode not in ["a", "a-", "r+"] and region is not None:
+ raise ValueError(
+ "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None"
+ )
- if mode not in ["w", "w-", "a", "r+"]:
+ if mode not in ["w", "w-", "a", "a-", "r+"]:
raise ValueError(
"The only supported options for mode are 'w', "
- f"'w-', 'a' and 'r+', but mode={mode!r}"
+ f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}"
)
# validate Dataset keys, DataArray names
@@ -1745,7 +1803,7 @@ def to_zarr(
write_empty=write_empty_chunks,
)
- if mode in ["a", "r+"]:
+ if mode in ["a", "a-", "r+"]:
_validate_datatypes_for_zarr_append(zstore, dataset)
if append_dim is not None:
existing_dims = zstore.get_dimensions()
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index a68a44b5f6f..d9385fc68a9 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -271,15 +271,6 @@ def prepare_variable(
dtype = _get_datatype(variable, raise_on_invalid_encoding=check_encoding)
fillvalue = attrs.pop("_FillValue", None)
- if dtype is str and fillvalue is not None:
- raise NotImplementedError(
- "h5netcdf does not yet support setting a fill value for "
- "variable-length strings "
- "(https://github.com/h5netcdf/h5netcdf/issues/37). "
- f"Either remove '_FillValue' from encoding on variable {name!r} "
- "or set {'dtype': 'S1'} in encoding to use the fixed width "
- "NC_CHAR type."
- )
if dtype is str:
dtype = h5py.special_dtype(vlen=str)
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index f21f15bf795..1aee4c1c726 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -490,16 +490,6 @@ def prepare_variable(
fill_value = attrs.pop("_FillValue", None)
- if datatype is str and fill_value is not None:
- raise NotImplementedError(
- "netCDF4 does not yet support setting a fill value for "
- "variable-length strings "
- "(https://github.com/Unidata/netcdf4-python/issues/730). "
- f"Either remove '_FillValue' from encoding on variable {name!r} "
- "or set {'dtype': 'S1'} in encoding to use the fixed width "
- "NC_CHAR type."
- )
-
encoding = _extract_nc4_variable_encoding(
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
)
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 656f91c604e..2cd5401795c 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -22,6 +22,7 @@
from xarray.core.common import zeros_like
from xarray.core.parallelcompat import guess_chunkmanager
from xarray.core.pycompat import integer_types
+from xarray.core.types import ZarrWriteModes
from xarray.core.utils import (
FrozenDict,
HiddenKeyDict,
@@ -223,15 +224,12 @@ def encode_zarr_attr_value(value):
class ZarrArrayWrapper(BackendArray):
- __slots__ = ("datastore", "dtype", "shape", "variable_name", "_array")
-
- def __init__(self, variable_name, datastore):
- self.datastore = datastore
- self.variable_name = variable_name
+ __slots__ = ("dtype", "shape", "_array")
+ def __init__(self, zarr_array):
# some callers attempt to evaluate an array if an `array` property exists on the object.
# we prefix with _ to avoid this inference.
- self._array = self.datastore.zarr_group[self.variable_name]
+ self._array = zarr_array
self.shape = self._array.shape
# preserve vlen string object dtype (GH 7328)
@@ -248,10 +246,10 @@ def get_array(self):
return self._array
def _oindex(self, key):
- return self.get_array().oindex[key]
+ return self._array.oindex[key]
def __getitem__(self, key):
- array = self.get_array()
+ array = self._array
if isinstance(key, indexing.BasicIndexer):
return array[key.tuple]
elif isinstance(key, indexing.VectorizedIndexer):
@@ -339,8 +337,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
# DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
# this avoids the need to get involved in zarr synchronization / locking
# From zarr docs:
- # "If each worker in a parallel computation is writing to a separate
- # region of the array, and if region boundaries are perfectly aligned
+ # "If each worker in a parallel computation is writing to a
+ # separate region of the array, and if region boundaries are perfectly aligned
# with chunk boundaries, then no synchronization is required."
# TODO: incorporate synchronizer to allow writes from multiple dask
# threads
@@ -543,13 +541,14 @@ class ZarrStore(AbstractWritableDataStore):
"_write_region",
"_safe_chunks",
"_write_empty",
+ "_close_store_on_close",
)
@classmethod
def open_group(
cls,
store,
- mode="r",
+ mode: ZarrWriteModes = "r",
synchronizer=None,
group=None,
consolidated=False,
@@ -574,7 +573,8 @@ def open_group(
zarr_version = getattr(store, "_store_version", 2)
open_kwargs = dict(
- mode=mode,
+ # mode='a-' is a handcrafted xarray specialty
+ mode="a" if mode == "a-" else mode,
synchronizer=synchronizer,
path=group,
)
@@ -626,6 +626,7 @@ def open_group(
zarr_group = zarr.open_consolidated(store, **open_kwargs)
else:
zarr_group = zarr.open_group(store, **open_kwargs)
+ close_store_on_close = zarr_group.store is not store
return cls(
zarr_group,
mode,
@@ -634,6 +635,7 @@ def open_group(
write_region,
safe_chunks,
write_empty,
+ close_store_on_close,
)
def __init__(
@@ -645,6 +647,7 @@ def __init__(
write_region=None,
safe_chunks=True,
write_empty: bool | None = None,
+ close_store_on_close: bool = False,
):
self.zarr_group = zarr_group
self._read_only = self.zarr_group.read_only
@@ -656,6 +659,7 @@ def __init__(
self._write_region = write_region
self._safe_chunks = safe_chunks
self._write_empty = write_empty
+ self._close_store_on_close = close_store_on_close
@property
def ds(self):
@@ -663,7 +667,7 @@ def ds(self):
return self.zarr_group
def open_store_variable(self, name, zarr_array):
- data = indexing.LazilyIndexedArray(ZarrArrayWrapper(name, self))
+ data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
try_nczarr = self._mode == "r"
dimensions, attributes = _get_zarr_dims_and_attrs(
zarr_array, DIMENSION_KEY, try_nczarr
@@ -761,8 +765,9 @@ def store(
"""
import zarr
+ existing_keys = tuple(self.zarr_group.array_keys())
existing_variable_names = {
- vn for vn in variables if _encode_variable_name(vn) in self.zarr_group
+ vn for vn in variables if _encode_variable_name(vn) in existing_keys
}
new_variables = set(variables) - existing_variable_names
variables_without_encoding = {vn: variables[vn] for vn in new_variables}
@@ -786,12 +791,10 @@ def store(
variables_encoded.update(vars_with_encoding)
for var_name in existing_variable_names:
- new_var = variables_encoded[var_name]
- existing_var = existing_vars[var_name]
- new_var = _validate_and_transpose_existing_dims(
+ variables_encoded[var_name] = _validate_and_transpose_existing_dims(
var_name,
- new_var,
- existing_var,
+ variables_encoded[var_name],
+ existing_vars[var_name],
self._write_region,
self._append_dim,
)
@@ -800,8 +803,21 @@ def store(
self.set_attributes(attributes)
self.set_dimensions(variables_encoded, unlimited_dims=unlimited_dims)
+ # if we are appending to an append_dim, only write either
+ # - new variables not already present, OR
+ # - variables with the append_dim in their dimensions
+ # We do NOT overwrite other variables.
+ if self._mode == "a-" and self._append_dim is not None:
+ variables_to_set = {
+ k: v
+ for k, v in variables_encoded.items()
+ if (k not in existing_variable_names) or (self._append_dim in v.dims)
+ }
+ else:
+ variables_to_set = variables_encoded
+
self.set_variables(
- variables_encoded, check_encoding_set, writer, unlimited_dims=unlimited_dims
+ variables_to_set, check_encoding_set, writer, unlimited_dims=unlimited_dims
)
if self._consolidate_on_close:
zarr.consolidate_metadata(self.zarr_group.store)
@@ -829,6 +845,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
import zarr
+ existing_keys = tuple(self.zarr_group.array_keys())
+
for vn, v in variables.items():
name = _encode_variable_name(vn)
check = vn in check_encoding_set
@@ -841,7 +859,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}
- if name in self.zarr_group:
+ if name in existing_keys:
# existing variable
# TODO: if mode="a", consider overriding the existing variable
# metadata. This would need some case work properly with region
@@ -923,7 +941,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
writer.add(v.data, zarr_array, region)
def close(self):
- pass
+ if self._close_store_on_close:
+ self.zarr_group.store.close()
def open_zarr(
diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py
index df660f90d9e..487197605e8 100644
--- a/xarray/coding/variables.py
+++ b/xarray/coding/variables.py
@@ -562,3 +562,15 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
def decode(self):
raise NotImplementedError()
+
+
+class ObjectVLenStringCoder(VariableCoder):
+ def encode(self):
+        raise NotImplementedError
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if variable.dtype == object and variable.encoding.get("dtype", False) == str:
+ variable = variable.astype(variable.encoding["dtype"])
+ return variable
+ else:
+ return variable
diff --git a/xarray/conventions.py b/xarray/conventions.py
index cf207f0c37a..8c7d6be2309 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -52,16 +52,32 @@ def _var_as_tuple(var: Variable) -> T_VarTuple:
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
-def _infer_dtype(array, name: T_Name = None) -> np.dtype:
- """Given an object array with no missing values, infer its dtype from its
- first element
- """
+def _infer_dtype(array, name=None):
+ """Given an object array with no missing values, infer its dtype from all elements."""
if array.dtype.kind != "O":
raise TypeError("infer_type must be called on a dtype=object array")
if array.size == 0:
return np.dtype(float)
+ native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
+ if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
+ raise ValueError(
+ "unable to infer dtype on variable {!r}; object array "
+ "contains mixed native types: {}".format(
+ name, ", ".join(x.__name__ for x in native_dtypes)
+ )
+ )
+
element = array[(0,) * array.ndim]
# We use the base types to avoid subclasses of bytes and str (which might
# not play nice with e.g. hdf5 datatypes), such as those from numpy
@@ -265,6 +281,10 @@ def decode_cf_variable(
var = strings.CharacterArrayCoder().decode(var, name=name)
var = strings.EncodedStringCoder().decode(var)
+ if original_dtype == object:
+ var = variables.ObjectVLenStringCoder().decode(var)
+ original_dtype = var.dtype
+
if mask_and_scale:
for coder in [
variables.UnsignedIntegerCoder(),
diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index 732ec5d3ea6..28857c2d26e 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -324,7 +324,7 @@ def assert_no_index_conflict(self) -> None:
"- they may be used to reindex data along common dimensions"
)
- def _need_reindex(self, dims, cmp_indexes) -> bool:
+ def _need_reindex(self, dim, cmp_indexes) -> bool:
"""Whether or not we need to reindex variables for a set of
matching indexes.
@@ -340,14 +340,14 @@ def _need_reindex(self, dims, cmp_indexes) -> bool:
return True
unindexed_dims_sizes = {}
- for dim in dims:
- if dim in self.unindexed_dim_sizes:
- sizes = self.unindexed_dim_sizes[dim]
+ for d in dim:
+ if d in self.unindexed_dim_sizes:
+ sizes = self.unindexed_dim_sizes[d]
if len(sizes) > 1:
# reindex if different sizes are found for unindexed dims
return True
else:
- unindexed_dims_sizes[dim] = next(iter(sizes))
+ unindexed_dims_sizes[d] = next(iter(sizes))
if unindexed_dims_sizes:
indexed_dims_sizes = {}
@@ -356,8 +356,8 @@ def _need_reindex(self, dims, cmp_indexes) -> bool:
for var in index_vars.values():
indexed_dims_sizes.update(var.sizes)
- for dim, size in unindexed_dims_sizes.items():
- if indexed_dims_sizes.get(dim, -1) != size:
+ for d, size in unindexed_dims_sizes.items():
+ if indexed_dims_sizes.get(d, -1) != size:
# reindex if unindexed dimension size doesn't match
return True
@@ -681,7 +681,7 @@ def align(
...
-def align( # type: ignore[misc]
+def align(
*objects: T_Alignable,
join: JoinOptions = "inner",
copy: bool = True,
@@ -1153,7 +1153,7 @@ def broadcast(
...
-def broadcast( # type: ignore[misc]
+def broadcast(
*args: T_Alignable, exclude: str | Iterable[Hashable] | None = None
) -> tuple[T_Alignable, ...]:
"""Explicitly broadcast any number of DataArray or Dataset objects against
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index eecd01d011e..1939e2c7d0f 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -739,7 +739,7 @@ def combine_by_coords(
dimension must have the same size in all objects.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
- "override"} or callable, default: "drop"
+ "override"} or callable, default: "no_conflicts"
A callable or a string indicating how to combine attrs of the objects being
merged:
diff --git a/xarray/core/common.py b/xarray/core/common.py
index fa0fa9aec0f..6dff9cc4024 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -21,6 +21,7 @@
emit_user_level_warning,
is_scalar,
)
+from xarray.namedarray.core import _raise_if_any_duplicate_dimensions
try:
import cftime
@@ -217,6 +218,7 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, .
return self._get_axis_num(dim)
def _get_axis_num(self: Any, dim: Hashable) -> int:
+ _raise_if_any_duplicate_dimensions(self.dims)
try:
return self.dims.index(dim)
except ValueError:
@@ -474,7 +476,7 @@ def _calc_assign_results(
def assign_coords(
self,
- coords: Mapping[Any, Any] | None = None,
+ coords: Mapping | None = None,
**coords_kwargs: Any,
) -> Self:
"""Assign new coordinates to this object.
@@ -484,15 +486,21 @@ def assign_coords(
Parameters
----------
- coords : dict-like or None, optional
- A dict where the keys are the names of the coordinates
- with the new values to assign. If the values are callable, they are
- computed on this object and assigned to new coordinate variables.
- If the values are not callable, (e.g. a ``DataArray``, scalar, or
- array), they are simply assigned. A new coordinate can also be
- defined and attached to an existing dimension using a tuple with
- the first element the dimension name and the second element the
- values for this new coordinate.
+ coords : mapping of dim to coord, optional
+ A mapping whose keys are the names of the coordinates and values are the
+ coordinates to assign. The mapping will generally be a dict or
+ :class:`Coordinates`.
+
+ * If a value is a standard data value — for example, a ``DataArray``,
+ scalar, or array — the data is simply assigned as a coordinate.
+
+ * If a value is callable, it is called with this object as the only
+ parameter, and the return value is used as new coordinate variables.
+
+ * A coordinate can also be defined and attached to an existing dimension
+ using a tuple with the first element the dimension name and the second
+ element the values for this new coordinate.
+
**coords_kwargs : optional
The keyword arguments form of ``coords``.
One of ``coords`` or ``coords_kwargs`` must be provided.
@@ -593,14 +601,6 @@ def assign_coords(
Attributes:
description: Weather-related data
- Notes
- -----
- Since ``coords_kwargs`` is a dictionary, the order of your arguments
- may not be preserved, and so the order of the new variables is not well
- defined. Assigning multiple variables within the same ``assign_coords``
- is possible, but you cannot reference other variables created within
- the same ``assign_coords`` call.
-
See Also
--------
Dataset.assign
@@ -1010,8 +1010,11 @@ def _resample(
if loffset is not None:
emit_user_level_warning(
- "Following pandas, the `loffset` parameter to resample will be deprecated "
- "in a future version of xarray. Switch to using time offset arithmetic.",
+ "Following pandas, the `loffset` parameter to resample is deprecated. "
+ "Switch to updating the resampled dataset time coordinate using "
+ "time offset arithmetic. For example:\n"
+ " >>> offset = pd.tseries.frequencies.to_offset(freq) / 2\n"
+ ' >>> resampled_ds["time"] = resampled_ds.get_index("time") + offset',
FutureWarning,
)
@@ -1164,7 +1167,7 @@ def _dataset_indexer(dim: Hashable) -> DataArray:
cond_wdim = cond.drop_vars(
var for var in cond if dim not in cond[var].dims
)
- keepany = cond_wdim.any(dim=(d for d in cond.dims.keys() if d != dim))
+ keepany = cond_wdim.any(dim=(d for d in cond.dims if d != dim))
return keepany.to_dataarray().any("variable")
_get_indexer = (
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 0c5c9d6d5cb..c6c7ef97e42 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -9,7 +9,7 @@
import warnings
from collections import Counter
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set
-from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload
+from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast, overload
import numpy as np
@@ -26,6 +26,7 @@
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar
from xarray.core.variable import Variable
+from xarray.util.deprecation_helpers import deprecate_dims
if TYPE_CHECKING:
from xarray.core.coordinates import Coordinates
@@ -1280,7 +1281,11 @@ def apply_ufunc(
def cov(
- da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None, ddof: int = 1
+ da_a: T_DataArray,
+ da_b: T_DataArray,
+ dim: Dims = None,
+ ddof: int = 1,
+ weights: T_DataArray | None = None,
) -> T_DataArray:
"""
Compute covariance between two DataArray objects along a shared dimension.
@@ -1296,6 +1301,8 @@ def cov(
ddof : int, default: 1
If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
else normalization is by N.
+ weights : DataArray, optional
+ Array of weights.
Returns
-------
@@ -1349,6 +1356,23 @@ def cov(
array([ 0.2 , -0.5 , 1.69333333])
Coordinates:
       * space    (space) <U2 'IA' 'IL' 'IN'
+    >>> weights = DataArray(
+ ... [4, 2, 1],
+ ... dims=("space"),
+ ... coords=[
+ ... ("space", ["IA", "IL", "IN"]),
+ ... ],
+ ... )
+ >>> weights
+    <xarray.DataArray (space: 3)>
+    array([4, 2, 1])
+    Coordinates:
+      * space    (space) <U2 'IA' 'IL' 'IN'
+    >>> xr.cov(da_a, da_b, dim="space", weights=weights)
+    <xarray.DataArray (time: 3)>
+    array([-4.69346939, -4.49632653, -3.37959184])
+    Coordinates:
+      * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03
"""
from xarray.core.dataarray import DataArray
@@ -1357,11 +1381,18 @@ def cov(
"Only xr.DataArray is supported."
f"Given {[type(arr) for arr in [da_a, da_b]]}."
)
-
- return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")
+ if weights is not None:
+ if not isinstance(weights, DataArray):
+ raise TypeError("Only xr.DataArray is supported." f"Given {type(weights)}.")
+ return _cov_corr(da_a, da_b, weights=weights, dim=dim, ddof=ddof, method="cov")
-def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
+def corr(
+ da_a: T_DataArray,
+ da_b: T_DataArray,
+ dim: Dims = None,
+ weights: T_DataArray | None = None,
+) -> T_DataArray:
"""
Compute the Pearson correlation coefficient between
two DataArray objects along a shared dimension.
@@ -1374,6 +1405,8 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
Array to compute.
dim : str, iterable of hashable, "..." or None, optional
The dimension along which the correlation will be computed
+ weights : DataArray, optional
+ Array of weights.
Returns
-------
@@ -1427,6 +1460,23 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
array([ 1., -1., 1.])
Coordinates:
       * space    (space) <U2 'IA' 'IL' 'IN'
+    >>> weights = DataArray(
+ ... [4, 2, 1],
+ ... dims=("space"),
+ ... coords=[
+ ... ("space", ["IA", "IL", "IN"]),
+ ... ],
+ ... )
+ >>> weights
+    <xarray.DataArray (space: 3)>
+    array([4, 2, 1])
+    Coordinates:
+      * space    (space) <U2 'IA' 'IL' 'IN'
+    >>> xr.corr(da_a, da_b, dim="space", weights=weights)
+    <xarray.DataArray (time: 3)>
+    array([-0.50240504, -0.83215028, -0.99057446])
+    Coordinates:
+      * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03
"""
from xarray.core.dataarray import DataArray
@@ -1435,13 +1485,16 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
"Only xr.DataArray is supported."
f"Given {[type(arr) for arr in [da_a, da_b]]}."
)
-
- return _cov_corr(da_a, da_b, dim=dim, method="corr")
+ if weights is not None:
+ if not isinstance(weights, DataArray):
+ raise TypeError("Only xr.DataArray is supported." f"Given {type(weights)}.")
+ return _cov_corr(da_a, da_b, weights=weights, dim=dim, method="corr")
def _cov_corr(
da_a: T_DataArray,
da_b: T_DataArray,
+ weights: T_DataArray | None = None,
dim: Dims = None,
ddof: int = 0,
method: Literal["cov", "corr", None] = None,
@@ -1457,28 +1510,46 @@ def _cov_corr(
valid_values = da_a.notnull() & da_b.notnull()
da_a = da_a.where(valid_values)
da_b = da_b.where(valid_values)
- valid_count = valid_values.sum(dim) - ddof
# 3. Detrend along the given dim
- demeaned_da_a = da_a - da_a.mean(dim=dim)
- demeaned_da_b = da_b - da_b.mean(dim=dim)
+ if weights is not None:
+ demeaned_da_a = da_a - da_a.weighted(weights).mean(dim=dim)
+ demeaned_da_b = da_b - da_b.weighted(weights).mean(dim=dim)
+ else:
+ demeaned_da_a = da_a - da_a.mean(dim=dim)
+ demeaned_da_b = da_b - da_b.mean(dim=dim)
# 4. Compute covariance along the given dim
# N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
# Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
- cov = (demeaned_da_a.conj() * demeaned_da_b).sum(
- dim=dim, skipna=True, min_count=1
- ) / (valid_count)
+ if weights is not None:
+ cov = (
+ (demeaned_da_a.conj() * demeaned_da_b)
+ .weighted(weights)
+ .mean(dim=dim, skipna=True)
+ )
+ else:
+ cov = (demeaned_da_a.conj() * demeaned_da_b).mean(dim=dim, skipna=True)
if method == "cov":
- return cov
+ # Adjust covariance for degrees of freedom
+ valid_count = valid_values.sum(dim)
+ adjust = valid_count / (valid_count - ddof)
+ # I think the cast is required because of `T_DataArray` + `T_Xarray` (would be
+ # the same with `T_DatasetOrArray`)
+ # https://github.com/pydata/xarray/pull/8384#issuecomment-1784228026
+ return cast(T_DataArray, cov * adjust)
else:
- # compute std + corr
- da_a_std = da_a.std(dim=dim)
- da_b_std = da_b.std(dim=dim)
+ # Compute std and corr
+ if weights is not None:
+ da_a_std = da_a.weighted(weights).std(dim=dim)
+ da_b_std = da_b.weighted(weights).std(dim=dim)
+ else:
+ da_a_std = da_a.std(dim=dim)
+ da_b_std = da_b.std(dim=dim)
corr = cov / (da_a_std * da_b_std)
- return corr
+ return cast(T_DataArray, corr)
def cross(
@@ -1691,9 +1762,10 @@ def cross(
return c
+@deprecate_dims
def dot(
*arrays,
- dims: Dims = None,
+ dim: Dims = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like ``np.einsum``, but
@@ -1703,7 +1775,7 @@ def dot(
----------
*arrays : DataArray or Variable
Arrays to compute.
- dims : str, iterable of hashable, "..." or None, optional
+ dim : str, iterable of hashable, "..." or None, optional
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
**kwargs : dict
@@ -1756,18 +1828,18 @@ def dot(
[3, 4, 5]])
Dimensions without coordinates: c, d
- >>> xr.dot(da_a, da_b, dims=["a", "b"])
+ >>> xr.dot(da_a, da_b, dim=["a", "b"])
array([110, 125])
Dimensions without coordinates: c
- >>> xr.dot(da_a, da_b, dims=["a"])
+ >>> xr.dot(da_a, da_b, dim=["a"])
array([[40, 46],
[70, 79]])
Dimensions without coordinates: b, c
- >>> xr.dot(da_a, da_b, da_c, dims=["b", "c"])
+ >>> xr.dot(da_a, da_b, da_c, dim=["b", "c"])
array([[ 9, 14, 19],
[ 93, 150, 207],
@@ -1779,7 +1851,7 @@ def dot(
array([110, 125])
Dimensions without coordinates: c
- >>> xr.dot(da_a, da_b, dims=...)
+ >>> xr.dot(da_a, da_b, dim=...)
array(235)
"""
@@ -1803,18 +1875,18 @@ def dot(
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}
- if dims is ...:
- dims = all_dims
- elif isinstance(dims, str):
- dims = (dims,)
- elif dims is None:
+ if dim is ...:
+ dim = all_dims
+ elif isinstance(dim, str):
+ dim = (dim,)
+ elif dim is None:
# find dimensions that occur more than one times
dim_counts: Counter = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
- dims = tuple(d for d, c in dim_counts.items() if c > 1)
+ dim = tuple(d for d, c in dim_counts.items() if c > 1)
- dot_dims: set[Hashable] = set(dims)
+ dot_dims: set[Hashable] = set(dim)
# dimensions to be parallelized
broadcast_dims = common_dims - dot_dims
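
A short sketch of the renamed keyword; via ``@deprecate_dims`` the old spelling should keep working but is expected to emit a deprecation warning (output omitted):

>>> import numpy as np
>>> import xarray as xr
>>> da_a = xr.DataArray(np.arange(6).reshape(2, 3), dims=["a", "b"])
>>> da_b = xr.DataArray(np.arange(6).reshape(2, 3), dims=["a", "b"])
>>> xr.dot(da_a, da_b, dim="a")  # new spelling
>>> xr.dot(da_a, da_b, dims="a")  # old spelling; expected to warn before being removed
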
diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index 8c558b38848..26cf36b3b07 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -315,7 +315,7 @@ def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, c
if dim in ds:
ds = ds.set_coords(dim)
concat_over.update(k for k, v in ds.variables.items() if dim in v.dims)
- concat_dim_lengths.append(ds.dims.get(dim, 1))
+ concat_dim_lengths.append(ds.sizes.get(dim, 1))
def process_subset_opt(opt, subset):
if isinstance(opt, str):
@@ -431,7 +431,7 @@ def _parse_datasets(
variables_order: dict[Hashable, Variable] = {} # variables in order of appearance
for ds in datasets:
- dims_sizes.update(ds.dims)
+ dims_sizes.update(ds.sizes)
all_coord_names.update(ds.coords)
data_vars.update(ds.data_vars)
variables_order.update(ds.variables)
@@ -536,9 +536,10 @@ def _dataset_concat(
result_encoding = datasets[0].encoding
# check that global attributes are fixed across all datasets if necessary
- for ds in datasets[1:]:
- if compat == "identical" and not utils.dict_equiv(ds.attrs, result_attrs):
- raise ValueError("Dataset global attributes not equal.")
+ if compat == "identical":
+ for ds in datasets[1:]:
+ if not utils.dict_equiv(ds.attrs, result_attrs):
+ raise ValueError("Dataset global attributes not equal.")
# we've already verified everything is consistent; now, calculate
# shared dimension sizes so we can expand the necessary variables
diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py
index 0c85b2a2d69..c59c5deba16 100644
--- a/xarray/core/coordinates.py
+++ b/xarray/core/coordinates.py
@@ -213,7 +213,7 @@ class Coordinates(AbstractCoordinates):
:py:class:`~xarray.Coordinates` object is passed, its indexes
will be added to the new created object.
indexes: dict-like, optional
- Mapping of where keys are coordinate names and values are
+ Mapping where keys are coordinate names and values are
:py:class:`~xarray.indexes.Index` objects. If None (default),
pandas indexes will be created for each dimension coordinate.
Passing an empty dictionary will skip this default behavior.
@@ -571,11 +571,18 @@ def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self:
Parameters
----------
- coords : :class:`Coordinates` or mapping of hashable to Any
- Mapping from coordinate names to the new values. If a ``Coordinates``
- object is passed, its indexes are assigned in the returned object.
- Otherwise, a default (pandas) index is created for each dimension
- coordinate found in the mapping.
+ coords : mapping of dim to coord, optional
+ A mapping whose keys are the names of the coordinates and values are the
+ coordinates to assign. The mapping will generally be a dict or
+ :class:`Coordinates`.
+
+ * If a value is a standard data value — for example, a ``DataArray``,
+ scalar, or array — the data is simply assigned as a coordinate.
+
+ * A coordinate can also be defined and attached to an existing dimension
+ using a tuple with the first element the dimension name and the second
+ element the values for this new coordinate.
+
**coords_kwargs
The keyword arguments form of ``coords``.
One of ``coords`` or ``coords_kwargs`` must be provided.
@@ -605,6 +612,7 @@ def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self:
* y_level_1 (y) int64 0 1 0 1
"""
+ # TODO: this doesn't support a callable, which is inconsistent with `DataArray.assign_coords`
coords = either_dict_or_kwargs(coords, coords_kwargs, "assign")
new_coords = self.copy()
new_coords.update(coords)
diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py
index d2d3e4a6d1c..98ff9002856 100644
--- a/xarray/core/dask_array_ops.py
+++ b/xarray/core/dask_array_ops.py
@@ -59,10 +59,11 @@ def push(array, n, axis):
"""
Dask-aware bottleneck.push
"""
- import bottleneck
import dask.array as da
import numpy as np
+ from xarray.core.duck_array_ops import _push
+
def _fill_with_last_one(a, b):
# cumreduction apply the push func over all the blocks first so, the only missing part is filling
# the missing values using the last data of the previous chunk
@@ -85,7 +86,7 @@ def _fill_with_last_one(a, b):
# The method parameter makes that the tests for python 3.7 fails.
return da.reductions.cumreduction(
- func=bottleneck.push,
+ func=_push,
binop=_fill_with_last_one,
ident=np.nan,
x=array,
diff --git a/xarray/core/daskmanager.py b/xarray/core/daskmanager.py
index 56d8dc9e23a..efa04bc3df2 100644
--- a/xarray/core/daskmanager.py
+++ b/xarray/core/daskmanager.py
@@ -97,6 +97,28 @@ def reduction(
keepdims=keepdims,
)
+ def scan(
+ self,
+ func: Callable,
+ binop: Callable,
+ ident: float,
+ arr: T_ChunkedArray,
+ axis: int | None = None,
+ dtype: np.dtype | None = None,
+ **kwargs,
+ ) -> DaskArray:
+ from dask.array.reductions import cumreduction
+
+ return cumreduction(
+ func,
+ binop,
+ ident,
+ arr,
+ axis=axis,
+ dtype=dtype,
+ **kwargs,
+ )
+
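
``DaskManager.scan`` is a thin wrapper around ``dask.array.reductions.cumreduction``; a small standalone sketch of that underlying call (cumulative sum over a chunked array) for orientation:

>>> import numpy as np
>>> import dask.array as da
>>> from dask.array.reductions import cumreduction
>>> x = da.from_array(np.arange(6.0), chunks=3)
>>> cumreduction(np.cumsum, np.add, 0.0, x, axis=0, dtype=x.dtype).compute()
array([ 0.,  1.,  3.,  6., 10., 15.])
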
def apply_gufunc(
self,
func: Callable,
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index b417470fdc0..0f245ff464b 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -49,7 +49,12 @@
from xarray.core.indexing import is_fancy_indexer, map_index_queries
from xarray.core.merge import PANDAS_TYPES, MergeError
from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.types import DaCompatible, T_DataArray, T_DataArrayOrSet
+from xarray.core.types import (
+ DaCompatible,
+ T_DataArray,
+ T_DataArrayOrSet,
+ ZarrWriteModes,
+)
from xarray.core.utils import (
Default,
HybridMappingProxy,
@@ -65,7 +70,7 @@
)
from xarray.plot.accessor import DataArrayPlotAccessor
from xarray.plot.utils import _get_units_from_attrs
-from xarray.util.deprecation_helpers import _deprecate_positional_args
+from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
if TYPE_CHECKING:
from typing import TypeVar, Union
@@ -75,7 +80,7 @@
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
- DaskDataFrame = None # type: ignore
+ DaskDataFrame = None
try:
from dask.delayed import Delayed
except ImportError:
@@ -115,14 +120,14 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset])
-def _check_coords_dims(shape, coords, dims):
- sizes = dict(zip(dims, shape))
+def _check_coords_dims(shape, coords, dim):
+ sizes = dict(zip(dim, shape))
for k, v in coords.items():
- if any(d not in dims for d in v.dims):
+ if any(d not in dim for d in v.dims):
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
- f"dimensions {dims}"
+ f"dimensions {dim}"
)
for d, s in v.sizes.items():
@@ -3036,7 +3041,7 @@ def T(self) -> Self:
def drop_vars(
self,
- names: Hashable | Iterable[Hashable],
+ names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]],
*,
errors: ErrorOptions = "raise",
) -> Self:
@@ -3044,8 +3049,9 @@ def drop_vars(
Parameters
----------
- names : Hashable or iterable of Hashable
- Name(s) of variables to drop.
+ names : Hashable or iterable of Hashable or Callable
+ Name(s) of variables to drop. If a Callable, this object is passed as its
+ only argument and its result is used.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', raises a ValueError error if any of the variable
passed are not in the dataset. If 'ignore', any given names that are in the
@@ -3095,7 +3101,17 @@ def drop_vars(
[ 6, 7, 8],
[ 9, 10, 11]])
Dimensions without coordinates: x, y
+
+ >>> da.drop_vars(lambda x: x.coords)
+
+ array([[ 0, 1, 2],
+ [ 3, 4, 5],
+ [ 6, 7, 8],
+ [ 9, 10, 11]])
+ Dimensions without coordinates: x, y
"""
+ if callable(names):
+ names = names(self)
ds = self._to_temp_dataset().drop_vars(names, errors=errors)
return self._from_temp_dataset(ds)
@@ -3909,6 +3925,23 @@ def to_netcdf(
) -> bytes:
...
+ # compute=False returns dask.Delayed
+ @overload
+ def to_netcdf(
+ self,
+ path: str | PathLike,
+ mode: Literal["w", "a"] = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ *,
+ compute: Literal[False],
+ invalid_netcdf: bool = False,
+ ) -> Delayed:
+ ...
+
# default return None
@overload
def to_netcdf(
@@ -3925,7 +3958,8 @@ def to_netcdf(
) -> None:
...
- # compute=False returns dask.Delayed
+ # if compute cannot be evaluated at type check time
+ # we may get back either Delayed or None
@overload
def to_netcdf(
self,
@@ -3936,10 +3970,9 @@ def to_netcdf(
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
- *,
- compute: Literal[False],
+ compute: bool = True,
invalid_netcdf: bool = False,
- ) -> Delayed:
+ ) -> Delayed | None:
...
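
With the reordered overloads, ``compute=False`` still statically resolves to ``Delayed``; a usage sketch (hypothetical file name, requires dask and a netCDF backend):

>>> import xarray as xr
>>> da = xr.DataArray([1.0, 2.0, 3.0], dims="x").chunk({"x": 1})
>>> delayed = da.to_netcdf("hypothetical_out.nc", compute=False)  # returns dask.delayed.Delayed
>>> delayed.compute()  # performs the deferred write
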
def to_netcdf(
@@ -4074,7 +4107,7 @@ def to_zarr(
self,
store: MutableMapping | str | PathLike[str] | None = None,
chunk_store: MutableMapping | str | PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
*,
@@ -4095,7 +4128,7 @@ def to_zarr(
self,
store: MutableMapping | str | PathLike[str] | None = None,
chunk_store: MutableMapping | str | PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
@@ -4114,7 +4147,7 @@ def to_zarr(
self,
store: MutableMapping | str | PathLike[str] | None = None,
chunk_store: MutableMapping | str | PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
@@ -4150,10 +4183,11 @@ def to_zarr(
chunk_store : MutableMapping, str or path-like, optional
Store or path to directory in local or remote file system only for Zarr
array chunks. Requires zarr-python v2.4.0 or later.
- mode : {"w", "w-", "a", "r+", None}, optional
+ mode : {"w", "w-", "a", "a-", r+", None}, optional
Persistence mode: "w" means create (overwrite if exists);
"w-" means create (fail if exists);
- "a" means override existing variables (create if does not exist);
+ "a" means override all existing variables including dimension coordinates (create if does not exist);
+ "a-" means only append those variables that have ``append_dim``.
"r+" means modify existing array *values* only (raise an error if
any metadata or shapes would change).
The default mode is "a" if ``append_dim`` is set. Otherwise, it is
@@ -4895,10 +4929,11 @@ def imag(self) -> Self:
"""
return self._replace(self.variable.imag)
+ @deprecate_dims
def dot(
self,
other: T_Xarray,
- dims: Dims = None,
+ dim: Dims = None,
) -> T_Xarray:
"""Perform dot product of two DataArrays along their shared dims.
@@ -4908,7 +4943,7 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
- dims : ..., str, Iterable of Hashable or None, optional
+ dim : ..., str, Iterable of Hashable or None, optional
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
If not specified, then all the common dimensions are summed over.
@@ -4947,7 +4982,7 @@ def dot(
if not isinstance(other, DataArray):
raise TypeError("dot only operates on DataArrays.")
- return computation.dot(self, other, dims=dims)
+ return computation.dot(self, other, dim=dim)
def sortby(
self,
@@ -6213,8 +6248,8 @@ def curvefit(
func: Callable[..., Any],
reduce_dims: Dims = None,
skipna: bool = True,
- p0: dict[str, float | DataArray] | None = None,
- bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
+ p0: Mapping[str, float | DataArray] | None = None,
+ bounds: Mapping[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
@@ -6888,14 +6923,90 @@ def rolling(
See Also
--------
- core.rolling.DataArrayRolling
+ DataArray.cumulative
Dataset.rolling
+ core.rolling.DataArrayRolling
"""
from xarray.core.rolling import DataArrayRolling
dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
return DataArrayRolling(self, dim, min_periods=min_periods, center=center)
+ def cumulative(
+ self,
+ dim: str | Iterable[Hashable],
+ min_periods: int = 1,
+ ) -> DataArrayRolling:
+ """
+ Accumulating object for DataArrays.
+
+ Parameters
+ ----------
+ dim : str or iterable of hashable
+ The name(s) of the dimension(s) to create the cumulative window along.
+ min_periods : int, default: 1
+ Minimum number of observations in window required to have a value
+ (otherwise result is NA). The default is 1 (note this is different
+ from ``Rolling``, whose default is the size of the window).
+
+ Returns
+ -------
+ core.rolling.DataArrayRolling
+
+ Examples
+ --------
+ Create a cumulative sum of monthly data along the ``time`` dimension:
+
+ >>> da = xr.DataArray(
+ ... np.linspace(0, 11, num=12),
+ ... coords=[
+ ... pd.date_range(
+ ... "1999-12-15",
+ ... periods=12,
+ ... freq=pd.DateOffset(months=1),
+ ... )
+ ... ],
+ ... dims="time",
+ ... )
+
+ >>> da
+
+ <xarray.DataArray (time: 12)>
+ Coordinates:
+ * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15
+
+ >>> da.cumulative("time").sum()
+ <xarray.DataArray (time: 12)>
+ array([ 0., 1., 3., 6., 10., 15., 21., 28., 36., 45., 55., 66.])
+ Coordinates:
+ * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15
+
+ See Also
+ --------
+ DataArray.rolling
+ Dataset.cumulative
+ core.rolling.DataArrayRolling
+ """
+ from xarray.core.rolling import DataArrayRolling
+
+ # Could we abstract this "normalize and check 'dim'" logic? It's currently shared
+ # with the same method in Dataset.
+ if isinstance(dim, str):
+ if dim not in self.dims:
+ raise ValueError(
+ f"Dimension {dim} not found in data dimensions: {self.dims}"
+ )
+ dim = {dim: self.sizes[dim]}
+ else:
+ missing_dims = set(dim) - set(self.dims)
+ if missing_dims:
+ raise ValueError(
+ f"Dimensions {missing_dims} not found in data dimensions: {self.dims}"
+ )
+ dim = {d: self.sizes[d] for d in dim}
+
+ return DataArrayRolling(self, dim, min_periods=min_periods, center=False)
+
def coarsen(
self,
dim: Mapping[Any, int] | None = None,
@@ -7031,6 +7142,12 @@ def resample(
loffset : timedelta or str, optional
Offset used to adjust the resampled time labels. Some pandas date
offset strings are supported.
+
+ .. deprecated:: 2023.03.0
+ Following pandas, the ``loffset`` parameter is deprecated in favor
+ of using time offset arithmetic, and will be removed in a future
+ version of xarray.
+
restore_coord_dims : bool, optional
If True, also restore the dimension order of multi-dimensional
coordinates.
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index c8e7564d3ca..a6fc0e2ca18 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -98,18 +98,22 @@
Self,
T_ChunkDim,
T_Chunks,
+ T_DataArray,
T_DataArrayOrSet,
T_Dataset,
+ ZarrWriteModes,
)
from xarray.core.utils import (
Default,
Frozen,
+ FrozenMappingWarningOnValuesAccess,
HybridMappingProxy,
OrderedSet,
_default,
decode_numpy_dict_values,
drop_dims_from_indexers,
either_dict_or_kwargs,
+ emit_user_level_warning,
infix_dims,
is_dict_like,
is_scalar,
@@ -167,7 +171,7 @@
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
- DaskDataFrame = None # type: ignore
+ DaskDataFrame = None
# list of attributes of pd.DatetimeIndex that are ndarrays of time info
@@ -776,14 +780,15 @@ def dims(self) -> Frozen[Hashable, int]:
Note that type of this object differs from `DataArray.dims`.
See `Dataset.sizes` and `DataArray.sizes` for consistently named
- properties.
+ properties. This property will be changed to return a type more consistent with
+ `DataArray.dims` in the future, i.e. a set of dimension names.
See Also
--------
Dataset.sizes
DataArray.dims
"""
- return Frozen(self._dims)
+ return FrozenMappingWarningOnValuesAccess(self._dims)
@property
def sizes(self) -> Frozen[Hashable, int]:
@@ -798,7 +803,7 @@ def sizes(self) -> Frozen[Hashable, int]:
--------
DataArray.sizes
"""
- return self.dims
+ return Frozen(self._dims)
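
In practice the split means ``Dataset.sizes`` stays the plain mapping of dimension name to length, while value access through ``Dataset.dims`` is expected to warn ahead of the planned type change; a sketch:

>>> import numpy as np
>>> import xarray as xr
>>> ds = xr.Dataset({"v": (("x", "y"), np.zeros((2, 3)))})
>>> ds.sizes["x"]  # unaffected mapping access
2
>>> ds.dims["x"]  # same value, but expected to emit a FutureWarning via FrozenMappingWarningOnValuesAccess
2
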
@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
@@ -1409,7 +1414,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
variables[name] = self._variables[name]
except KeyError:
ref_name, var_name, var = _get_virtual_variable(
- self._variables, name, self.dims
+ self._variables, name, self.sizes
)
variables[var_name] = var
if ref_name in self._coord_names or ref_name in self.dims:
@@ -1424,7 +1429,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
for v in variables.values():
needed_dims.update(v.dims)
- dims = {k: self.dims[k] for k in needed_dims}
+ dims = {k: self.sizes[k] for k in needed_dims}
# preserves ordering of coordinates
for k in self._variables:
@@ -1446,7 +1451,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
try:
variable = self._variables[name]
except KeyError:
- _, name, variable = _get_virtual_variable(self._variables, name, self.dims)
+ _, name, variable = _get_virtual_variable(self._variables, name, self.sizes)
needed_dims = set(variable.dims)
@@ -1473,7 +1478,7 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords)
# virtual coordinates
- yield HybridMappingProxy(keys=self.dims, mapping=self)
+ yield HybridMappingProxy(keys=self.sizes, mapping=self)
def __contains__(self, key: object) -> bool:
"""The 'in' operator will return true or false depending on whether
@@ -1539,10 +1544,18 @@ def __getitem__(
Indexing with a list of names will return a new ``Dataset`` object.
"""
+ from xarray.core.formatting import shorten_list_repr
+
if utils.is_dict_like(key):
return self.isel(**key)
if utils.hashable(key):
- return self._construct_dataarray(key)
+ try:
+ return self._construct_dataarray(key)
+ except KeyError as e:
+ raise KeyError(
+ f"No variable named {key!r}. Variables on the dataset include {shorten_list_repr(list(self.variables.keys()), max_items=10)}"
+ ) from e
+
if utils.iterable_of_hashable(key):
return self._copy_listed(key)
raise ValueError(f"Unsupported key-type {type(key)}")
@@ -2149,6 +2162,23 @@ def to_netcdf(
) -> bytes:
...
+ # compute=False returns dask.Delayed
+ @overload
+ def to_netcdf(
+ self,
+ path: str | PathLike,
+ mode: Literal["w", "a"] = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Any, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ *,
+ compute: Literal[False],
+ invalid_netcdf: bool = False,
+ ) -> Delayed:
+ ...
+
# default return None
@overload
def to_netcdf(
@@ -2165,7 +2195,8 @@ def to_netcdf(
) -> None:
...
- # compute=False returns dask.Delayed
+ # if compute cannot be evaluated at type check time
+ # we may get back either Delayed or None
@overload
def to_netcdf(
self,
@@ -2176,10 +2207,9 @@ def to_netcdf(
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Any, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
- *,
- compute: Literal[False],
+ compute: bool = True,
invalid_netcdf: bool = False,
- ) -> Delayed:
+ ) -> Delayed | None:
...
def to_netcdf(
@@ -2297,7 +2327,7 @@ def to_zarr(
self,
store: MutableMapping | str | PathLike[str] | None = None,
chunk_store: MutableMapping | str | PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
@@ -2320,7 +2350,7 @@ def to_zarr(
self,
store: MutableMapping | str | PathLike[str] | None = None,
chunk_store: MutableMapping | str | PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
@@ -2341,7 +2371,7 @@ def to_zarr(
self,
store: MutableMapping | str | PathLike[str] | None = None,
chunk_store: MutableMapping | str | PathLike | None = None,
- mode: Literal["w", "w-", "a", "r+", None] = None,
+ mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
@@ -2379,10 +2409,11 @@ def to_zarr(
chunk_store : MutableMapping, str or path-like, optional
Store or path to directory in local or remote file system only for Zarr
array chunks. Requires zarr-python v2.4.0 or later.
- mode : {"w", "w-", "a", "r+", None}, optional
+ mode : {"w", "w-", "a", "a-", r+", None}, optional
Persistence mode: "w" means create (overwrite if exists);
"w-" means create (fail if exists);
- "a" means override existing variables (create if does not exist);
+ "a" means override all existing variables including dimension coordinates (create if does not exist);
+ "a-" means only append those variables that have ``append_dim``.
"r+" means modify existing array *values* only (raise an error if
any metadata or shapes would change).
The default mode is "a" if ``append_dim`` is set. Otherwise, it is
@@ -2541,7 +2572,7 @@ def info(self, buf: IO | None = None) -> None:
lines = []
lines.append("xarray.Dataset {")
lines.append("dimensions:")
- for name, size in self.dims.items():
+ for name, size in self.sizes.items():
lines.append(f"\t{name} = {size} ;")
lines.append("\nvariables:")
for name, da in self.variables.items():
@@ -2669,10 +2700,10 @@ def chunk(
else:
chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
- bad_dims = chunks_mapping.keys() - self.dims.keys()
+ bad_dims = chunks_mapping.keys() - self.sizes.keys()
if bad_dims:
raise ValueError(
- f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.dims)}"
+ f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
)
chunkmanager = guess_chunkmanager(chunked_array_type)
@@ -3924,7 +3955,7 @@ def maybe_variable(obj, k):
try:
return obj._variables[k]
except KeyError:
- return as_variable((k, range(obj.dims[k])))
+ return as_variable((k, range(obj.sizes[k])))
def _validate_interp_indexer(x, new_x):
# In the case of datetimes, the restrictions placed on indexers
@@ -4148,7 +4179,7 @@ def _rename_vars(
return variables, coord_names
def _rename_dims(self, name_dict: Mapping[Any, Hashable]) -> dict[Hashable, int]:
- return {name_dict.get(k, k): v for k, v in self.dims.items()}
+ return {name_dict.get(k, k): v for k, v in self.sizes.items()}
def _rename_indexes(
self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable]
@@ -5140,7 +5171,7 @@ def _get_stack_index(
if dim in self._variables:
var = self._variables[dim]
else:
- _, _, var = _get_virtual_variable(self._variables, dim, self.dims)
+ _, _, var = _get_virtual_variable(self._variables, dim, self.sizes)
# dummy index (only `stack_coords` will be used to construct the multi-index)
stack_index = PandasIndex([0], dim)
stack_coords = {dim: var}
@@ -5167,7 +5198,7 @@ def _stack_once(
if any(d in var.dims for d in dims):
add_dims = [d for d in dims if d not in var.dims]
vdims = list(var.dims) + add_dims
- shape = [self.dims[d] for d in vdims]
+ shape = [self.sizes[d] for d in vdims]
exp_var = var.set_dims(vdims, shape)
stacked_var = exp_var.stack(**{new_dim: dims})
new_variables[name] = stacked_var
@@ -5713,7 +5744,7 @@ def _assert_all_in_dataset(
def drop_vars(
self,
- names: Hashable | Iterable[Hashable],
+ names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]],
*,
errors: ErrorOptions = "raise",
) -> Self:
@@ -5721,8 +5752,9 @@ def drop_vars(
Parameters
----------
- names : hashable or iterable of hashable
- Name(s) of variables to drop.
+ names : Hashable or iterable of Hashable or Callable
+ Name(s) of variables to drop. If a Callable, this object is passed as its
+ only argument and its result is used.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', raises a ValueError error if any of the variable
passed are not in the dataset. If 'ignore', any given names that are in the
@@ -5764,7 +5796,7 @@ def drop_vars(
humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
- # Drop the 'humidity' variable
+ Drop the 'humidity' variable
>>> dataset.drop_vars(["humidity"])
@@ -5777,7 +5809,7 @@ def drop_vars(
temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
- # Drop the 'humidity', 'temperature' variables
+ Drop the 'humidity', 'temperature' variables
>>> dataset.drop_vars(["humidity", "temperature"])
@@ -5789,7 +5821,18 @@ def drop_vars(
Data variables:
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
- # Attempt to drop non-existent variable with errors="ignore"
+ Drop all indexes
+
+ >>> dataset.drop_vars(lambda x: x.indexes)
+ <xarray.Dataset>
+ Dimensions: (time: 1, latitude: 2, longitude: 2)
+ Dimensions without coordinates: time, latitude, longitude
+ Data variables:
+ temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0
+ humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
+ wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
+
+ Attempt to drop non-existent variable with errors="ignore"
>>> dataset.drop_vars(["pressure"], errors="ignore")
@@ -5803,7 +5846,7 @@ def drop_vars(
humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
- # Attempt to drop non-existent variable with errors="raise"
+ Attempt to drop non-existent variable with errors="raise"
>>> dataset.drop_vars(["pressure"], errors="raise")
Traceback (most recent call last):
@@ -5823,24 +5866,26 @@ def drop_vars(
DataArray.drop_vars
"""
+ if callable(names):
+ names = names(self)
# the Iterable check is required for mypy
if is_scalar(names) or not isinstance(names, Iterable):
- names = {names}
+ names_set = {names}
else:
- names = set(names)
+ names_set = set(names)
if errors == "raise":
- self._assert_all_in_dataset(names)
+ self._assert_all_in_dataset(names_set)
# GH6505
other_names = set()
- for var in names:
+ for var in names_set:
maybe_midx = self._indexes.get(var, None)
if isinstance(maybe_midx, PandasMultiIndex):
idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim])
- idx_other_names = idx_coord_names - set(names)
+ idx_other_names = idx_coord_names - set(names_set)
other_names.update(idx_other_names)
if other_names:
- names |= set(other_names)
+ names_set |= set(other_names)
warnings.warn(
f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. "
f"Please also drop the following variables: {other_names!r} to avoid an error in the future.",
@@ -5848,11 +5893,11 @@ def drop_vars(
stacklevel=2,
)
- assert_no_index_corrupted(self.xindexes, names)
+ assert_no_index_corrupted(self.xindexes, names_set)
- variables = {k: v for k, v in self._variables.items() if k not in names}
+ variables = {k: v for k, v in self._variables.items() if k not in names_set}
coord_names = {k for k in self._coord_names if k in variables}
- indexes = {k: v for k, v in self._indexes.items() if k not in names}
+ indexes = {k: v for k, v in self._indexes.items() if k not in names_set}
return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
)
@@ -5934,10 +5979,9 @@ def drop(
raise ValueError('errors must be either "raise" or "ignore"')
if is_dict_like(labels) and not isinstance(labels, dict):
- warnings.warn(
- "dropping coordinates using `drop` is be deprecated; use drop_vars.",
- FutureWarning,
- stacklevel=2,
+ emit_user_level_warning(
+ "dropping coordinates using `drop` is deprecated; use drop_vars.",
+ DeprecationWarning,
)
return self.drop_vars(labels, errors=errors)
@@ -5947,11 +5991,13 @@ def drop(
labels = either_dict_or_kwargs(labels, labels_kwargs, "drop")
if dim is None and (is_scalar(labels) or isinstance(labels, Iterable)):
- warnings.warn(
- "dropping variables using `drop` will be deprecated; using drop_vars is encouraged.",
- PendingDeprecationWarning,
- stacklevel=2,
+ emit_user_level_warning(
+ "dropping variables using `drop` is deprecated; use drop_vars.",
+ DeprecationWarning,
)
+ # for mypy
+ if is_scalar(labels):
+ labels = [labels]
return self.drop_vars(labels, errors=errors)
if dim is not None:
warnings.warn(
@@ -5962,10 +6008,9 @@ def drop(
)
return self.drop_sel({dim: labels}, errors=errors, **labels_kwargs)
- warnings.warn(
- "dropping labels using `drop` will be deprecated; using drop_sel is encouraged.",
- PendingDeprecationWarning,
- stacklevel=2,
+ emit_user_level_warning(
+ "dropping labels using `drop` is deprecated; use `drop_sel` instead.",
+ DeprecationWarning,
)
return self.drop_sel(labels, errors=errors)
@@ -6309,7 +6354,7 @@ def dropna(
if subset is None:
subset = iter(self.data_vars)
- count = np.zeros(self.dims[dim], dtype=np.int64)
+ count = np.zeros(self.sizes[dim], dtype=np.int64)
size = np.int_(0) # for type checking
for k in subset:
@@ -6317,7 +6362,7 @@ def dropna(
if dim in array.dims:
dims = [d for d in array.dims if d != dim]
count += np.asarray(array.count(dims))
- size += math.prod([self.dims[d] for d in dims])
+ size += math.prod([self.sizes[d] for d in dims])
if thresh is not None:
mask = count >= thresh
@@ -7094,7 +7139,7 @@ def _normalize_dim_order(
f"Dataset: {list(self.dims)}"
)
- ordered_dims = {k: self.dims[k] for k in dim_order}
+ ordered_dims = {k: self.sizes[k] for k in dim_order}
return ordered_dims
@@ -7354,7 +7399,7 @@ def to_dask_dataframe(
var = self.variables[name]
except KeyError:
# dimension without a matching coordinate
- size = self.dims[name]
+ size = self.sizes[name]
data = da.arange(size, chunks=size, dtype=np.int64)
var = Variable((name,), data)
@@ -7427,7 +7472,7 @@ def to_dict(
d: dict = {
"coords": {},
"attrs": decode_numpy_dict_values(self.attrs),
- "dims": dict(self.dims),
+ "dims": dict(self.sizes),
"data_vars": {},
}
for k in self.coords:
@@ -7969,8 +8014,8 @@ def sortby(
variables = variables
arrays = [v if isinstance(v, DataArray) else self[v] for v in variables]
aligned_vars = align(self, *arrays, join="left")
- aligned_self = aligned_vars[0]
- aligned_other_vars: tuple[DataArray, ...] = aligned_vars[1:]
+ aligned_self = cast("Self", aligned_vars[0])
+ aligned_other_vars = cast(tuple[DataArray, ...], aligned_vars[1:])
vars_by_dim = defaultdict(list)
for data_array in aligned_other_vars:
if data_array.ndim != 1:
@@ -9510,6 +9555,68 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self:
"Dataset.argmin() with a sequence or ... for dim"
)
+ def eval(
+ self,
+ statement: str,
+ *,
+ parser: QueryParserOptions = "pandas",
+ ) -> Self | T_DataArray:
+ """
+ Calculate an expression supplied as a string in the context of the dataset.
+
+ This is currently experimental; the API may change, particularly around
+ assignments, which currently return a ``Dataset`` with the additional variable.
+ Currently only the ``python`` engine is supported, which has the same
+ performance as executing in python.
+
+ Parameters
+ ----------
+ statement : str
+ String containing the Python-like expression to evaluate.
+
+ Returns
+ -------
+ result : Dataset or DataArray, depending on whether ``statement`` contains an
+ assignment.
+
+ Examples
+ --------
+ >>> ds = xr.Dataset(
+ ... {"a": ("x", np.arange(0, 5, 1)), "b": ("x", np.linspace(0, 1, 5))}
+ ... )
+ >>> ds
+ <xarray.Dataset>
+ Dimensions: (x: 5)
+ Dimensions without coordinates: x
+ Data variables:
+ a (x) int64 0 1 2 3 4
+ b (x) float64 0.0 0.25 0.5 0.75 1.0
+
+ >>> ds.eval("a + b")
+ <xarray.DataArray (x: 5)>
+ array([0. , 1.25, 2.5 , 3.75, 5. ])
+ Dimensions without coordinates: x
+
+ >>> ds.eval("c = a + b")
+ <xarray.Dataset>
+ Dimensions: (x: 5)
+ Dimensions without coordinates: x
+ Data variables:
+ a (x) int64 0 1 2 3 4
+ b (x) float64 0.0 0.25 0.5 0.75 1.0
+ c (x) float64 0.0 1.25 2.5 3.75 5.0
+ """
+
+ return pd.eval(
+ statement,
+ resolvers=[self],
+ target=self,
+ parser=parser,
+ # Because numexpr returns a numpy array, using that engine results in
+ # different behavior. We'd be very open to a contribution handling this.
+ engine="python",
+ )
+
def query(
self,
queries: Mapping[Any, Any] | None = None,
@@ -9612,8 +9719,8 @@ def curvefit(
func: Callable[..., Any],
reduce_dims: Dims = None,
skipna: bool = True,
- p0: dict[str, float | DataArray] | None = None,
- bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
+ p0: Mapping[str, float | DataArray] | None = None,
+ bounds: Mapping[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
@@ -10262,14 +10369,60 @@ def rolling(
See Also
--------
- core.rolling.DatasetRolling
+ Dataset.cumulative
DataArray.rolling
+ core.rolling.DatasetRolling
"""
from xarray.core.rolling import DatasetRolling
dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
return DatasetRolling(self, dim, min_periods=min_periods, center=center)
+ def cumulative(
+ self,
+ dim: str | Iterable[Hashable],
+ min_periods: int = 1,
+ ) -> DatasetRolling:
+ """
+ Accumulating object for Datasets
+
+ Parameters
+ ----------
+ dim : str or iterable of hashable
+ The name(s) of the dimension(s) to create the cumulative window along.
+ min_periods : int, default: 1
+ Minimum number of observations in window required to have a value
+ (otherwise result is NA). The default is 1 (note this is different
+ from ``Rolling``, whose default is the size of the window).
+
+ Returns
+ -------
+ core.rolling.DatasetRolling
+
+ See Also
+ --------
+ Dataset.rolling
+ DataArray.cumulative
+ core.rolling.DatasetRolling
+ """
+ from xarray.core.rolling import DatasetRolling
+
+ if isinstance(dim, str):
+ if dim not in self.dims:
+ raise ValueError(
+ f"Dimension {dim} not found in data dimensions: {self.dims}"
+ )
+ dim = {dim: self.sizes[dim]}
+ else:
+ missing_dims = set(dim) - set(self.dims)
+ if missing_dims:
+ raise ValueError(
+ f"Dimensions {missing_dims} not found in data dimensions: {self.dims}"
+ )
+ dim = {d: self.sizes[d] for d in dim}
+
+ return DatasetRolling(self, dim, min_periods=min_periods, center=False)
+
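
The Dataset docstring above has no example; a minimal sketch mirroring the DataArray one:

>>> import numpy as np
>>> import xarray as xr
>>> ds = xr.Dataset({"foo": ("time", np.arange(4.0))})
>>> ds.cumulative("time").sum()["foo"].values
array([0., 1., 3., 6.])
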
def coarsen(
self,
dim: Mapping[Any, int] | None = None,
@@ -10374,6 +10527,12 @@ def resample(
loffset : timedelta or str, optional
Offset used to adjust the resampled time labels. Some pandas date
offset strings are supported.
+
+ .. deprecated:: 2023.03.0
+ Following pandas, the ``loffset`` parameter is deprecated in favor
+ of using time offset arithmetic, and will be removed in a future
+ version of xarray.
+
restore_coord_dims : bool, optional
If True, also restore the dimension order of multi-dimensional
coordinates.
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index b9f7db9737f..84cfd7f6fdc 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -31,8 +31,10 @@
from numpy import concatenate as _concatenate
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
from numpy.lib.stride_tricks import sliding_window_view # noqa
+from packaging.version import Version
-from xarray.core import dask_array_ops, dtypes, nputils
+from xarray.core import dask_array_ops, dtypes, nputils, pycompat
+from xarray.core.options import OPTIONS
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.core.pycompat import array_type, is_duck_dask_array
from xarray.core.utils import is_duck_array, module_available
@@ -333,7 +335,10 @@ def fillna(data, other):
def concatenate(arrays, axis=0):
"""concatenate() with better dtype promotion rules."""
- if hasattr(arrays[0], "__array_namespace__"):
+ # TODO: remove the additional check once `numpy` adds `concat` to its array namespace
+ if hasattr(arrays[0], "__array_namespace__") and not isinstance(
+ arrays[0], np.ndarray
+ ):
xp = get_array_namespace(arrays[0])
return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
return _concatenate(as_shared_dtype(arrays), axis=axis)
@@ -688,13 +693,44 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
-def push(array, n, axis):
- from bottleneck import push
+def _push(array, n: int | None = None, axis: int = -1):
+ """
+ Use either bottleneck or numbagg depending on options & what's available
+ """
+ if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
+ raise RuntimeError(
+ "ffill & bfill requires bottleneck or numbagg to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
+ )
+ if OPTIONS["use_numbagg"] and module_available("numbagg"):
+ import numbagg
+
+ if pycompat.mod_version("numbagg") < Version("0.6.2"):
+ warnings.warn(
+ f"numbagg >= 0.6.2 is required for bfill & ffill; {pycompat.mod_version('numbagg')} is installed. We'll attempt with bottleneck instead."
+ )
+ else:
+ return numbagg.ffill(array, limit=n, axis=axis)
+
+ # work around for bottleneck 178
+ limit = n if n is not None else array.shape[axis]
+
+ import bottleneck as bn
+
+ return bn.push(array, limit, axis)
+
+
+def push(array, n, axis):
+ if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
+ raise RuntimeError(
+ "ffill & bfill requires bottleneck or numbagg to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
+ )
if is_duck_dask_array(array):
return dask_array_ops.push(array, n, axis)
else:
- return push(array, n, axis)
+ return _push(array, n, axis)
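
A usage sketch of how the new dispatch is driven from user options (which backend actually runs depends on what is installed):

>>> import numpy as np
>>> import xarray as xr
>>> da = xr.DataArray([1.0, np.nan, np.nan, 4.0], dims="x")
>>> with xr.set_options(use_numbagg=True):
...     filled = da.ffill("x")  # numbagg.ffill if numbagg >= 0.6.2 is installed
>>> with xr.set_options(use_numbagg=False, use_bottleneck=True):
...     filled = da.ffill("x")  # falls back to bottleneck.push
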
def _first_last_wrapper(array, *, axis, op, keepdims):
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py
index a915e9acbf3..92bfe2fbfc4 100644
--- a/xarray/core/formatting.py
+++ b/xarray/core/formatting.py
@@ -6,7 +6,7 @@
import functools
import math
from collections import defaultdict
-from collections.abc import Collection, Hashable
+from collections.abc import Collection, Hashable, Sequence
from datetime import datetime, timedelta
from itertools import chain, zip_longest
from reprlib import recursive_repr
@@ -357,7 +357,7 @@ def summarize_attr(key, value, col_width=None):
def _calculate_col_width(col_items):
- max_name_length = max(len(str(s)) for s in col_items) if col_items else 0
+ max_name_length = max((len(str(s)) for s in col_items), default=0)
col_width = max(max_name_length, 7) + 6
return col_width
@@ -739,7 +739,7 @@ def dataset_repr(ds):
def diff_dim_summary(a, b):
- if a.dims != b.dims:
+ if a.sizes != b.sizes:
return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})"
else:
return ""
@@ -937,3 +937,16 @@ def diff_dataset_repr(a, b, compat):
summary.append(diff_attrs_repr(a.attrs, b.attrs, compat))
return "\n".join(summary)
+
+
+def shorten_list_repr(items: Sequence, max_items: int) -> str:
+ if len(items) <= max_items:
+ return repr(items)
+ else:
+ first_half = repr(items[: max_items // 2])[
+ 1:-1
+ ] # Convert to string and remove brackets
+ second_half = repr(items[-max_items // 2 :])[
+ 1:-1
+ ] # Convert to string and remove brackets
+ return f"[{first_half}, ..., {second_half}]"
diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py
index 3627554cf57..efd74111823 100644
--- a/xarray/core/formatting_html.py
+++ b/xarray/core/formatting_html.py
@@ -37,17 +37,18 @@ def short_data_repr_html(array) -> str:
return f"
{text}
"
-def format_dims(dims, dims_with_index) -> str:
- if not dims:
+def format_dims(dim_sizes, dims_with_index) -> str:
+ if not dim_sizes:
return ""
dim_css_map = {
- dim: " class='xr-has-index'" if dim in dims_with_index else "" for dim in dims
+ dim: " class='xr-has-index'" if dim in dims_with_index else ""
+ for dim in dim_sizes
}
dims_li = "".join(
f"
" f"{escape(str(dim))}: {size}
"
- for dim, size in dims.items()
+ for dim, size in dim_sizes.items()
)
return f"
{dims_li}
"
@@ -204,7 +205,7 @@ def _mapping_section(
def dim_section(obj) -> str:
- dim_list = format_dims(obj.dims, obj.xindexes.dims)
+ dim_list = format_dims(obj.sizes, obj.xindexes.dims)
return collapsible_section(
"Dimensions", inline_details=dim_list, enabled=False, collapsed=True
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 8c81d3e6a96..15bd8d1e35b 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -36,6 +36,7 @@
from xarray.core.pycompat import integer_types
from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray
from xarray.core.utils import (
+ FrozenMappingWarningOnValuesAccess,
either_dict_or_kwargs,
hashable,
is_scalar,
@@ -1519,7 +1520,7 @@ def dims(self) -> Frozen[Hashable, int]:
if self._dims is None:
self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims
- return self._dims
+ return FrozenMappingWarningOnValuesAccess(self._dims)
def map(
self,
diff --git a/xarray/core/missing.py b/xarray/core/missing.py
index 90a9dd2e76c..b55fd6049a6 100644
--- a/xarray/core/missing.py
+++ b/xarray/core/missing.py
@@ -14,7 +14,7 @@
from xarray.core.common import _contains_datetime_like_objects, ones_like
from xarray.core.computation import apply_ufunc
from xarray.core.duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
-from xarray.core.options import OPTIONS, _get_keep_attrs
+from xarray.core.options import _get_keep_attrs
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.core.types import Interp1dOptions, InterpOptions
from xarray.core.utils import OrderedSet, is_scalar
@@ -413,11 +413,6 @@ def _bfill(arr, n=None, axis=-1):
def ffill(arr, dim=None, limit=None):
"""forward fill missing values"""
- if not OPTIONS["use_bottleneck"]:
- raise RuntimeError(
- "ffill requires bottleneck to be enabled."
- " Call `xr.set_options(use_bottleneck=True)` to enable it."
- )
axis = arr.get_axis_num(dim)
@@ -436,11 +431,6 @@ def ffill(arr, dim=None, limit=None):
def bfill(arr, dim=None, limit=None):
"""backfill missing values"""
- if not OPTIONS["use_bottleneck"]:
- raise RuntimeError(
- "bfill requires bottleneck to be enabled."
- " Call `xr.set_options(use_bottleneck=True)` to enable it."
- )
axis = arr.get_axis_num(dim)
diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py
index 316a77ead6a..96e5548b9b4 100644
--- a/xarray/core/nputils.py
+++ b/xarray/core/nputils.py
@@ -1,12 +1,16 @@
from __future__ import annotations
import warnings
+from typing import Callable
import numpy as np
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
from packaging.version import Version
+from xarray.core import pycompat
+from xarray.core.utils import module_available
+
# remove once numpy 2.0 is the oldest supported version
try:
from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore]
@@ -25,15 +29,6 @@
bn = np
_BOTTLENECK_AVAILABLE = False
-try:
- import numbagg
-
- _HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0")
-except ImportError:
- # use numpy methods instead
- numbagg = np
- _HAS_NUMBAGG = False
-
def _select_along_axis(values, idx, axis):
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
@@ -171,17 +166,16 @@ def __setitem__(self, key, value):
self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions)
-def _create_method(name, npmodule=np):
+def _create_method(name, npmodule=np) -> Callable:
def f(values, axis=None, **kwargs):
dtype = kwargs.get("dtype", None)
bn_func = getattr(bn, name, None)
- nba_func = getattr(numbagg, name, None)
if (
- _HAS_NUMBAGG
+ module_available("numbagg")
+ and pycompat.mod_version("numbagg") >= Version("0.5.0")
and OPTIONS["use_numbagg"]
and isinstance(values, np.ndarray)
- and nba_func is not None
# numbagg uses ddof=1 only, but numpy uses ddof=0 by default
and (("var" in name or "std" in name) and kwargs.get("ddof", 0) == 1)
# TODO: bool?
@@ -189,11 +183,15 @@ def f(values, axis=None, **kwargs):
# and values.dtype.isnative
and (dtype is None or np.dtype(dtype) == values.dtype)
):
- # numbagg does not take care dtype, ddof
- kwargs.pop("dtype", None)
- kwargs.pop("ddof", None)
- result = nba_func(values, axis=axis, **kwargs)
- elif (
+ import numbagg
+
+ nba_func = getattr(numbagg, name, None)
+ if nba_func is not None:
+ # numbagg does not handle dtype or ddof
+ kwargs.pop("dtype", None)
+ kwargs.pop("ddof", None)
+ return nba_func(values, axis=axis, **kwargs)
+ if (
_BOTTLENECK_AVAILABLE
and OPTIONS["use_bottleneck"]
and isinstance(values, np.ndarray)
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index dd5232023a2..ef505b55345 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -4,19 +4,29 @@
import itertools
import operator
from collections.abc import Hashable, Iterable, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
import numpy as np
from xarray.core.alignment import align
+from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
+from xarray.core.indexes import Index
+from xarray.core.merge import merge
from xarray.core.pycompat import is_dask_collection
if TYPE_CHECKING:
from xarray.core.types import T_Xarray
+class ExpectedDict(TypedDict):
+ shapes: dict[Hashable, int]
+ coords: set[Hashable]
+ data_vars: set[Hashable]
+ indexes: dict[Hashable, Index]
+
+
def unzip(iterable):
return zip(*iterable)
@@ -31,7 +41,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset):
def check_result_variables(
- result: DataArray | Dataset, expected: Mapping[str, Any], kind: str
+ result: DataArray | Dataset,
+ expected: ExpectedDict,
+ kind: Literal["coords", "data_vars"],
):
if kind == "coords":
nice_str = "coordinate"
@@ -186,8 +198,9 @@ def map_blocks(
Returns
-------
- A single DataArray or Dataset with dask backend, reassembled from the outputs of the
- function.
+ obj : same as obj
+ A single DataArray or Dataset with dask backend, reassembled from the outputs of the
+ function.
Notes
-----
@@ -253,7 +266,7 @@ def _wrapper(
args: list,
kwargs: dict,
arg_is_array: Iterable[bool],
- expected: dict,
+ expected: ExpectedDict,
):
"""
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
@@ -344,33 +357,45 @@ def _wrapper(
for arg in aligned
)
+ merged_coordinates = merge([arg.coords for arg in aligned]).coords
+
_, npargs = unzip(
sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
)
# check that chunk sizes are compatible
input_chunks = dict(npargs[0].chunks)
- input_indexes = dict(npargs[0]._indexes)
for arg in xarray_objs[1:]:
assert_chunks_compatible(npargs[0], arg)
input_chunks.update(arg.chunks)
- input_indexes.update(arg._indexes)
+ coordinates: Coordinates
if template is None:
# infer template by providing zero-shaped arrays
template = infer_template(func, aligned[0], *args, **kwargs)
- template_indexes = set(template._indexes)
- preserved_indexes = template_indexes & set(input_indexes)
- new_indexes = template_indexes - set(input_indexes)
- indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
- indexes.update({k: template._indexes[k] for k in new_indexes})
+ template_coords = set(template.coords)
+ preserved_coord_vars = template_coords & set(merged_coordinates)
+ new_coord_vars = template_coords - set(merged_coordinates)
+
+ preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
+ # preserved_coords contains all coordinate variables that share a dimension
+ # with any index variable in preserved_coord_vars
+ # Drop any unneeded vars in a second pass, this is required for e.g.
+ # if the mapped function were to drop a non-dimension coordinate variable.
+ preserved_coords = preserved_coords.drop_vars(
+ tuple(k for k in preserved_coords.variables if k not in template_coords)
+ )
+
+ coordinates = merge(
+ (preserved_coords, template.coords.to_dataset()[new_coord_vars])
+ ).coords
output_chunks: Mapping[Hashable, tuple[int, ...]] = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
}
else:
# template xarray object has been provided with proper sizes and chunk shapes
- indexes = dict(template._indexes)
+ coordinates = template.coords
output_chunks = template.chunksizes
if not output_chunks:
raise ValueError(
@@ -472,6 +497,9 @@ def subset_dataset_to_block(
return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
+ # variable names that depend on the computation. Currently, indexes
+ # cannot be modified in the mapped function, so we exclude those.
+ computed_variables = set(template.variables) - set(coordinates.xindexes)
# iterate over all possible chunk combinations
for chunk_tuple in itertools.product(*ichunk.values()):
# mapping from dimension name to chunk index
@@ -484,19 +512,23 @@ def subset_dataset_to_block(
for isxr, arg in zip(is_xarray, npargs)
]
- # expected["shapes", "coords", "data_vars", "indexes"] are used to
# raise nice error messages in _wrapper
- expected = {}
- # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
- # even if length of dimension is changed by the applied function
- expected["shapes"] = {
- k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks
- }
- expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment]
- expected["coords"] = set(template.coords.keys()) # type: ignore[assignment]
- expected["indexes"] = {
- dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)]
- for dim in indexes
+ expected: ExpectedDict = {
+ # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
+ # even if length of dimension is changed by the applied function
+ "shapes": {
+ k: output_chunks[k][v]
+ for k, v in chunk_index.items()
+ if k in output_chunks
+ },
+ "data_vars": set(template.data_vars.keys()),
+ "coords": set(template.coords.keys()),
+ "indexes": {
+ dim: coordinates.xindexes[dim][
+ _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
+ ]
+ for dim in coordinates.xindexes
+ },
}
from_wrapper = (gname,) + chunk_tuple
@@ -504,9 +536,8 @@ def subset_dataset_to_block(
# mapping from variable name to dask graph key
var_key_map: dict[Hashable, str] = {}
- for name, variable in template.variables.items():
- if name in indexes:
- continue
+ for name in computed_variables:
+ variable = template.variables[name]
gname_l = f"{name}-{gname}"
var_key_map[name] = gname_l
@@ -542,12 +573,7 @@ def subset_dataset_to_block(
},
)
- # TODO: benbovy - flexible indexes: make it work with custom indexes
- # this will need to pass both indexes and coords to the Dataset constructor
- result = Dataset(
- coords={k: idx.to_pandas_index() for k, idx in indexes.items()},
- attrs=template.attrs,
- )
+ result = Dataset(coords=coordinates, attrs=template.attrs)
for index in result._indexes:
result[index].attrs = template[index].attrs
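
A minimal ``map_blocks`` sketch exercising the template path that the coordinate-handling rework above touches (requires dask; output omitted):

>>> import numpy as np
>>> import xarray as xr
>>> da = xr.DataArray(np.arange(6.0), dims="x", coords={"x": np.arange(6)}).chunk({"x": 3})
>>> result = xr.map_blocks(lambda block: block + 1, da, template=da)  # template shares dims, coords and chunks
>>> result.compute()
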
diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py
index 333059e00ae..37542925dde 100644
--- a/xarray/core/parallelcompat.py
+++ b/xarray/core/parallelcompat.py
@@ -403,6 +403,43 @@ def reduction(
"""
raise NotImplementedError()
+ def scan(
+ self,
+ func: Callable,
+ binop: Callable,
+ ident: float,
+ arr: T_ChunkedArray,
+ axis: int | None = None,
+ dtype: np.dtype | None = None,
+ **kwargs,
+ ) -> T_ChunkedArray:
+ """
+ General version of a 1D scan, also known as a cumulative array reduction.
+
+ Used in ``ffill`` and ``bfill`` in xarray.
+
+ Parameters
+ ----------
+ func: callable
+ Cumulative function like np.cumsum or np.cumprod
+ binop: callable
+ Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
+ ident: Number
+ Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
+ arr: dask Array
+ axis: int, optional
+ dtype: dtype
+
+ Returns
+ -------
+ Chunked array
+
+ See also
+ --------
+ dask.array.cumreduction
+ """
+ raise NotImplementedError()
+
@abstractmethod
def apply_gufunc(
self,
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index bc8b61164f1..32ef408f7cc 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -12,7 +12,7 @@
integer_types = (int, np.integer)
if TYPE_CHECKING:
- ModType = Literal["dask", "pint", "cupy", "sparse", "cubed"]
+ ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"]
DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic
@@ -47,6 +47,9 @@ def __init__(self, mod: ModType) -> None:
duck_array_type = (duck_array_module.SparseArray,)
elif mod == "cubed":
duck_array_type = (duck_array_module.Array,)
+ # Not a duck array module, but using this system regardless, to get lazy imports
+ elif mod == "numbagg":
+ duck_array_type = ()
else:
raise NotImplementedError
diff --git a/xarray/core/resample.py b/xarray/core/resample.py
index d78676b188e..c93faa31612 100644
--- a/xarray/core/resample.py
+++ b/xarray/core/resample.py
@@ -63,7 +63,7 @@ def _drop_coords(self) -> T_Xarray:
obj = self._obj
for k, v in obj.coords.items():
if k != self._dim and self._dim in v.dims:
- obj = obj.drop_vars(k)
+ obj = obj.drop_vars([k])
return obj
def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray:
@@ -244,7 +244,7 @@ def map(
# dimension, then we need to do so before we can rename the proxy
# dimension we used.
if self._dim in combined.coords:
- combined = combined.drop_vars(self._dim)
+ combined = combined.drop_vars([self._dim])
if RESAMPLE_DIM in combined.dims:
combined = combined.rename({RESAMPLE_DIM: self._dim})
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index 8f21fe37072..819c31642d0 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -8,13 +8,14 @@
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
import numpy as np
+from packaging.version import Version
-from xarray.core import dtypes, duck_array_ops, utils
+from xarray.core import dtypes, duck_array_ops, pycompat, utils
from xarray.core.arithmetic import CoarsenArithmetic
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.types import CoarsenBoundaryOptions, SideOptions, T_Xarray
-from xarray.core.utils import either_dict_or_kwargs
+from xarray.core.utils import either_dict_or_kwargs, module_available
try:
import bottleneck
@@ -145,7 +146,13 @@ def _reduce_method( # type: ignore[misc]
name: str, fillna: Any, rolling_agg_func: Callable | None = None
) -> Callable[..., T_Xarray]:
"""Constructs reduction methods built on a numpy reduction function (e.g. sum),
- a bottleneck reduction function (e.g. move_sum), or a Rolling reduction (_mean).
+ a numbagg reduction function (e.g. move_sum), a bottleneck reduction function
+ (e.g. move_sum), or a Rolling reduction (_mean).
+
+ The logic here for which function to run is quite diffuse, across this method &
+ _array_reduce. Arguably we could refactor this. But one constraint is that we
+ need context of xarray options, of the functions each library offers, of
+ the array (e.g. dtype).
"""
if rolling_agg_func:
array_agg_func = None
@@ -153,14 +160,21 @@ def _reduce_method( # type: ignore[misc]
array_agg_func = getattr(duck_array_ops, name)
bottleneck_move_func = getattr(bottleneck, "move_" + name, None)
+ if module_available("numbagg"):
+ import numbagg
+
+ numbagg_move_func = getattr(numbagg, "move_" + name, None)
+ else:
+ numbagg_move_func = None
def method(self, keep_attrs=None, **kwargs):
keep_attrs = self._get_keep_attrs(keep_attrs)
- return self._numpy_or_bottleneck_reduce(
- array_agg_func,
- bottleneck_move_func,
- rolling_agg_func,
+ return self._array_reduce(
+ array_agg_func=array_agg_func,
+ bottleneck_move_func=bottleneck_move_func,
+ numbagg_move_func=numbagg_move_func,
+ rolling_agg_func=rolling_agg_func,
keep_attrs=keep_attrs,
fillna=fillna,
**kwargs,
@@ -510,9 +524,47 @@ def _counts(self, keep_attrs: bool | None) -> DataArray:
)
return counts
- def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
- from xarray.core.dataarray import DataArray
+ def _numbagg_reduce(self, func, keep_attrs, **kwargs):
+ # Some of this is copied from `_bottleneck_reduce`, we could reduce this as part
+ # of a wider refactor.
+
+ axis = self.obj.get_axis_num(self.dim[0])
+ padded = self.obj.variable
+ if self.center[0]:
+ if is_duck_dask_array(padded.data):
+ # workaround to make the padded chunk size larger than
+ # self.window - 1
+ shift = -(self.window[0] + 1) // 2
+ offset = (self.window[0] - 1) // 2
+ valid = (slice(None),) * axis + (
+ slice(offset, offset + self.obj.shape[axis]),
+ )
+ else:
+ shift = (-self.window[0] // 2) + 1
+ valid = (slice(None),) * axis + (slice(-shift, None),)
+ padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")
+
+ if is_duck_dask_array(padded.data) and False:
+ raise AssertionError("should not be reachable")
+ else:
+ values = func(
+ padded.data,
+ window=self.window[0],
+ min_count=self.min_periods,
+ axis=axis,
+ )
+
+ if self.center[0]:
+ values = values[valid]
+
+ attrs = self.obj.attrs if keep_attrs else {}
+
+ return self.obj.__class__(
+ values, self.obj.coords, attrs=attrs, name=self.obj.name
+ )
+
+ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
# bottleneck doesn't allow min_count to be 0, although it should
# work the same as if min_count = 1
# Note bottleneck only works with 1d-rolling.
@@ -550,12 +602,15 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
attrs = self.obj.attrs if keep_attrs else {}
- return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name)
+ return self.obj.__class__(
+ values, self.obj.coords, attrs=attrs, name=self.obj.name
+ )
- def _numpy_or_bottleneck_reduce(
+ def _array_reduce(
self,
array_agg_func,
bottleneck_move_func,
+ numbagg_move_func,
rolling_agg_func,
keep_attrs,
fillna,
@@ -571,6 +626,35 @@ def _numpy_or_bottleneck_reduce(
)
del kwargs["dim"]
+ if (
+ OPTIONS["use_numbagg"]
+ and module_available("numbagg")
+ and pycompat.mod_version("numbagg") >= Version("0.6.3")
+ and numbagg_move_func is not None
+ # TODO: we could at least allow this for the equivalent of `apply_ufunc`'s
+ # "parallelized". `rolling_exp` does this, as an example (but rolling_exp is
+ # much simpler)
+ and not is_duck_dask_array(self.obj.data)
+ # Numbagg doesn't handle object arrays and generally has dtype consistency,
+ # so doesn't deal well with bool arrays which are expected to change type.
+ and self.obj.data.dtype.kind not in "ObMm"
+ # TODO: we could also allow this, probably as part of a refactoring of this
+ # module, so we can use the machinery in `self.reduce`.
+ and self.ndim == 1
+ ):
+ import numbagg
+
+ # Numbagg has a default ddof of 1. I (@max-sixty) think we should make
+ # this the default in xarray too, but until we do, don't use numbagg for
+ # std and var unless ddof is set to 1.
+ if (
+ numbagg_move_func not in [numbagg.move_std, numbagg.move_var]
+ or kwargs.get("ddof") == 1
+ ):
+ return self._numbagg_reduce(
+ numbagg_move_func, keep_attrs=keep_attrs, **kwargs
+ )
+
if (
OPTIONS["use_bottleneck"]
and bottleneck_move_func is not None
@@ -583,8 +667,10 @@ def _numpy_or_bottleneck_reduce(
return self._bottleneck_reduce(
bottleneck_move_func, keep_attrs=keep_attrs, **kwargs
)
+
if rolling_agg_func:
return rolling_agg_func(self, keep_attrs=self._get_keep_attrs(keep_attrs))
+
if fillna is not None:
if fillna is dtypes.INF:
fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True)
@@ -705,7 +791,7 @@ def _counts(self, keep_attrs: bool | None) -> Dataset:
DataArrayRolling._counts, keep_attrs=keep_attrs
)
- def _numpy_or_bottleneck_reduce(
+ def _array_reduce(
self,
array_agg_func,
bottleneck_move_func,
@@ -715,7 +801,7 @@ def _numpy_or_bottleneck_reduce(
):
return self._dataset_implementation(
functools.partial(
- DataArrayRolling._numpy_or_bottleneck_reduce,
+ DataArrayRolling._array_reduce,
array_agg_func=array_agg_func,
bottleneck_move_func=bottleneck_move_func,
rolling_agg_func=rolling_agg_func,
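
For context, a rough usage sketch of the dispatch above (assumes numbagg >= 0.6.3 is installed; the array contents and window sizes are illustrative):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.randn(1_000), dims="x")

    # 1D, float, non-dask input: _array_reduce routes this to numbagg.move_mean.
    with xr.set_options(use_numbagg=True):
        smoothed = da.rolling(x=10, min_periods=5).mean()

    # std/var only take the numbagg path when ddof=1 (numbagg's default);
    # otherwise they fall back to bottleneck or numpy as before.
    with xr.set_options(use_numbagg=True):
        spread = da.rolling(x=10).std(ddof=1)
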
diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py
index c8160cefef3..144e26a86b2 100644
--- a/xarray/core/rolling_exp.py
+++ b/xarray/core/rolling_exp.py
@@ -6,18 +6,12 @@
import numpy as np
from packaging.version import Version
+from xarray.core import pycompat
from xarray.core.computation import apply_ufunc
from xarray.core.options import _get_keep_attrs
from xarray.core.pdcompat import count_not_none
from xarray.core.types import T_DataWithCoords
-
-try:
- import numbagg
- from numbagg import move_exp_nanmean, move_exp_nansum
-
- _NUMBAGG_VERSION: Version | None = Version(numbagg.__version__)
-except ImportError:
- _NUMBAGG_VERSION = None
+from xarray.core.utils import module_available
def _get_alpha(
@@ -25,51 +19,34 @@ def _get_alpha(
span: float | None = None,
halflife: float | None = None,
alpha: float | None = None,
-) -> float:
- # pandas defines in terms of com (converting to alpha in the algo)
- # so use its function to get a com and then convert to alpha
-
- com = _get_center_of_mass(com, span, halflife, alpha)
- return 1 / (1 + com)
-
-
-def _get_center_of_mass(
- comass: float | None,
- span: float | None,
- halflife: float | None,
- alpha: float | None,
) -> float:
"""
- Vendored from pandas.core.window.common._get_center_of_mass
-
- See licenses/PANDAS_LICENSE for the function's license
+ Convert com, span, halflife to alpha.
"""
- valid_count = count_not_none(comass, span, halflife, alpha)
+ valid_count = count_not_none(com, span, halflife, alpha)
if valid_count > 1:
- raise ValueError("comass, span, halflife, and alpha are mutually exclusive")
+ raise ValueError("com, span, halflife, and alpha are mutually exclusive")
- # Convert to center of mass; domain checks ensure 0 < alpha <= 1
- if comass is not None:
- if comass < 0:
- raise ValueError("comass must satisfy: comass >= 0")
+ # Convert to alpha
+ if com is not None:
+ if com < 0:
+ raise ValueError("commust satisfy: com>= 0")
+ return 1 / (com + 1)
elif span is not None:
if span < 1:
raise ValueError("span must satisfy: span >= 1")
- comass = (span - 1) / 2.0
+ return 2 / (span + 1)
elif halflife is not None:
if halflife <= 0:
raise ValueError("halflife must satisfy: halflife > 0")
- decay = 1 - np.exp(np.log(0.5) / halflife)
- comass = 1 / decay - 1
+ return 1 - np.exp(np.log(0.5) / halflife)
elif alpha is not None:
- if alpha <= 0 or alpha > 1:
+ if not 0 < alpha <= 1:
raise ValueError("alpha must satisfy: 0 < alpha <= 1")
- comass = (1.0 - alpha) / alpha
+ return alpha
else:
raise ValueError("Must pass one of comass, span, halflife, or alpha")
- return float(comass)
-
class RollingExp(Generic[T_DataWithCoords]):
"""
@@ -100,17 +77,17 @@ def __init__(
window_type: str = "span",
min_weight: float = 0.0,
):
- if _NUMBAGG_VERSION is None:
+ if not module_available("numbagg"):
raise ImportError(
"numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
)
- elif _NUMBAGG_VERSION < Version("0.2.1"):
+ elif pycompat.mod_version("numbagg") < Version("0.2.1"):
raise ImportError(
- f"numbagg >= 0.2.1 is required for rolling_exp but currently version {_NUMBAGG_VERSION} is installed"
+ f"numbagg >= 0.2.1 is required for rolling_exp but currently version {pycompat.mod_version('numbagg')} is installed"
)
- elif _NUMBAGG_VERSION < Version("0.3.1") and min_weight > 0:
+ elif pycompat.mod_version("numbagg") < Version("0.3.1") and min_weight > 0:
raise ImportError(
- f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {_NUMBAGG_VERSION} is installed"
+ f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {pycompat.mod_version('numbagg')} is installed"
)
self.obj: T_DataWithCoords = obj
@@ -144,13 +121,15 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
Dimensions without coordinates: x
"""
+ import numbagg
+
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
dim_order = self.obj.dims
return apply_ufunc(
- move_exp_nanmean,
+ numbagg.move_exp_nanmean,
self.obj,
input_core_dims=[[self.dim]],
kwargs=self.kwargs,
@@ -180,13 +159,15 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
Dimensions without coordinates: x
"""
+ import numbagg
+
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
dim_order = self.obj.dims
return apply_ufunc(
- move_exp_nansum,
+ numbagg.move_exp_nansum,
self.obj,
input_core_dims=[[self.dim]],
kwargs=self.kwargs,
@@ -211,10 +192,12 @@ def std(self) -> T_DataWithCoords:
Dimensions without coordinates: x
"""
- if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
+ if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
- f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {_NUMBAGG_VERSION} is installed"
+ f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
)
+ import numbagg
+
dim_order = self.obj.dims
return apply_ufunc(
@@ -242,12 +225,12 @@ def var(self) -> T_DataWithCoords:
array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281])
Dimensions without coordinates: x
"""
-
- if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
+ if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
- f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {_NUMBAGG_VERSION} is installed"
+ f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
+ import numbagg
return apply_ufunc(
numbagg.move_exp_nanvar,
@@ -275,11 +258,12 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
Dimensions without coordinates: x
"""
- if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
+ if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
- f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
+ f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
+ import numbagg
return apply_ufunc(
numbagg.move_exp_nancov,
@@ -308,11 +292,12 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
Dimensions without coordinates: x
"""
- if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
+ if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
- f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
+ f"numbagg >= 0.4.0 is required for rolling_exp().corr(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
+ import numbagg
return apply_ufunc(
numbagg.move_exp_nancorr,
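
The rewritten _get_alpha converts directly to alpha (com -> 1/(1+com), span -> 2/(span+1), halflife -> 1-exp(ln(0.5)/halflife), alpha -> alpha). A small sketch of that equivalence through the public API, assuming numbagg is installed:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.randn(100), dims="x")

    # span=20 corresponds to alpha = 2 / (20 + 1), so these two agree:
    by_span = da.rolling_exp(x=20, window_type="span").mean()
    by_alpha = da.rolling_exp(x=2 / 21, window_type="alpha").mean()
    xr.testing.assert_allclose(by_span, by_alpha)
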
diff --git a/xarray/core/types.py b/xarray/core/types.py
index 1be5b00c43f..06ad65679d8 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -173,7 +173,8 @@ def copy(
# Temporary placeholder for indicating an array api compliant type.
# hopefully in the future we can narrow this down more:
-T_DuckArray = TypeVar("T_DuckArray", bound=Any)
+T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True)
+
ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
VarCompatible = Union["Variable", "ScalarOrArray"]
@@ -282,3 +283,6 @@ def copy(
"midpoint",
"nearest",
]
+
+
+ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]
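
ZarrWriteModes spells out the accepted ``to_zarr`` modes, including the new "a-" (append along ``append_dim`` without overwriting other variables). A hedged sketch, assuming zarr is installed; the store path and variable names are illustrative:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"t": ("time", np.arange(3.0))}, coords={"time": [0, 1, 2]})
    ds.to_zarr("example.zarr", mode="w")

    more = xr.Dataset({"t": ("time", np.arange(3.0))}, coords={"time": [3, 4, 5]})
    # "a-" appends along `time` but leaves variables without that dim untouched.
    more.to_zarr("example.zarr", mode="a-", append_dim="time")
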
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index ad86b2c7fec..00c84d4c10c 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -50,12 +50,15 @@
Collection,
Container,
Hashable,
+ ItemsView,
Iterable,
Iterator,
+ KeysView,
Mapping,
MutableMapping,
MutableSet,
Sequence,
+ ValuesView,
)
from enum import Enum
from typing import (
@@ -114,6 +117,8 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index):
elif hasattr(array, "categories"):
# category isn't a real numpy dtype
dtype = array.categories.dtype
+ if not is_valid_numpy_dtype(dtype):
+ dtype = np.dtype("O")
elif not is_valid_numpy_dtype(array.dtype):
dtype = np.dtype("O")
else:
@@ -471,6 +476,57 @@ def FrozenDict(*args, **kwargs) -> Frozen:
return Frozen(dict(*args, **kwargs))
+class FrozenMappingWarningOnValuesAccess(Frozen[K, V]):
+ """
+ Class which behaves like a Mapping but warns if the values are accessed.
+
+ Temporary object to aid in deprecation cycle of `Dataset.dims` (see GH issue #8496).
+ `Dataset.dims` is being changed from returning a mapping of dimension names to lengths to just
+ returning a frozen set of dimension names (to increase consistency with `DataArray.dims`).
+ This class retains backwards compatibility but raises a warning only if the return value
+ of ds.dims is used like a dictionary (i.e. it doesn't raise a warning if used in a way that
+ would also be valid for a FrozenSet, e.g. iteration).
+ """
+
+ __slots__ = ("mapping",)
+
+ def _warn(self) -> None:
+ emit_user_level_warning(
+ "The return type of `Dataset.dims` will be changed to return a set of dimension names in future, "
+ "in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, "
+ "please use `Dataset.sizes`.",
+ FutureWarning,
+ )
+
+ def __getitem__(self, key: K) -> V:
+ self._warn()
+ return super().__getitem__(key)
+
+ @overload
+ def get(self, key: K, /) -> V | None:
+ ...
+
+ @overload
+ def get(self, key: K, /, default: V | T) -> V | T:
+ ...
+
+ def get(self, key: K, default: T | None = None) -> V | T | None:
+ self._warn()
+ return super().get(key, default)
+
+ def keys(self) -> KeysView[K]:
+ self._warn()
+ return super().keys()
+
+ def items(self) -> ItemsView[K, V]:
+ self._warn()
+ return super().items()
+
+ def values(self) -> ValuesView[V]:
+ self._warn()
+ return super().values()
+
+
class HybridMappingProxy(Mapping[K, V]):
"""Implements the Mapping interface. Uses the wrapped mapping for item lookup
and a separate wrapped keys collection for iteration.
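
Roughly how the deprecation shim behaves from the user side (a sketch; ``Dataset.dims`` is wrapped in this class elsewhere in the release, and only dict-style access warns):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": (("x", "y"), np.zeros((2, 3)))})

    "x" in ds.dims   # set-like usage: silent
    ds.dims["x"]     # dict-style usage: emits the FutureWarning above
    ds.sizes["x"]    # recommended replacement, no warning
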
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index db109a40454..3add7a1441e 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -46,7 +46,7 @@
is_duck_array,
maybe_coerce_to_str,
)
-from xarray.namedarray.core import NamedArray
+from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
indexing.ExplicitlyIndexed,
@@ -1541,15 +1541,15 @@ def stack(self, dimensions=None, **dimensions_kwargs):
result = result._stack_once(dims, new_dim)
return result
- def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self:
+ def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self:
"""
Unstacks the variable without needing an index.
Unlike `_unstack_once`, this function requires the existing dimension to
contain the full product of the new dimensions.
"""
- new_dim_names = tuple(dims.keys())
- new_dim_sizes = tuple(dims.values())
+ new_dim_names = tuple(dim.keys())
+ new_dim_sizes = tuple(dim.values())
if old_dim not in self.dims:
raise ValueError(f"invalid existing dimension: {old_dim}")
@@ -2063,6 +2063,7 @@ def rank(self, dim, pct=False):
--------
Dataset.rank, DataArray.rank
"""
+ # This could / should arguably be implemented at the DataArray & Dataset level
if not OPTIONS["use_bottleneck"]:
raise RuntimeError(
"rank requires bottleneck to be enabled."
@@ -2071,24 +2072,20 @@ def rank(self, dim, pct=False):
import bottleneck as bn
- data = self.data
-
- if is_duck_dask_array(data):
- raise TypeError(
- "rank does not work for arrays stored as dask "
- "arrays. Load the data via .compute() or .load() "
- "prior to calling this method."
- )
- elif not isinstance(data, np.ndarray):
- raise TypeError(f"rank is not implemented for {type(data)} objects.")
-
- axis = self.get_axis_num(dim)
func = bn.nanrankdata if self.dtype.kind == "f" else bn.rankdata
- ranked = func(data, axis=axis)
+ ranked = xr.apply_ufunc(
+ func,
+ self,
+ input_core_dims=[[dim]],
+ output_core_dims=[[dim]],
+ dask="parallelized",
+ kwargs=dict(axis=-1),
+ ).transpose(*self.dims)
+
if pct:
- count = np.sum(~np.isnan(data), axis=axis, keepdims=True)
+ count = self.notnull().sum(dim)
ranked /= count
- return Variable(self.dims, ranked)
+ return ranked
def rolling_window(
self, dim, window, window_dim, center=False, fill_value=dtypes.NA
@@ -2599,7 +2596,7 @@ def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable:
"""
Use sparse-array as backend.
"""
- from xarray.namedarray.utils import _default as _default_named
+ from xarray.namedarray._typing import _default as _default_named
if sparse_format is _default:
sparse_format = _default_named
@@ -2879,11 +2876,8 @@ def _unified_dims(variables):
all_dims = {}
for var in variables:
var_dims = var.dims
- if len(set(var_dims)) < len(var_dims):
- raise ValueError(
- "broadcasting cannot handle duplicate "
- f"dimensions: {list(var_dims)!r}"
- )
+ _raise_if_any_duplicate_dimensions(var_dims, err_context="Broadcasting")
+
for d, s in zip(var_dims, var.shape):
if d not in all_dims:
all_dims[d] = s
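
With rank routed through apply_ufunc(dask="parallelized"), ranking no longer requires eagerly loading dask-backed data. A usage sketch, assuming bottleneck (and dask for the chunked case) are installed; the array values are illustrative:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.randn(4, 6), dims=("x", "y"))

    ranked = da.chunk({"x": 2}).rank("y")   # previously raised TypeError for dask inputs
    pct = da.rank("y", pct=True)            # percent ranks now use notnull().sum(dim)
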
diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py
index 28740a99020..53ff6db5f28 100644
--- a/xarray/core/weighted.py
+++ b/xarray/core/weighted.py
@@ -228,7 +228,7 @@ def _reduce(
# `dot` does not broadcast arrays, so this avoids creating a large
# DataArray (if `weights` has additional dimensions)
- return dot(da, weights, dims=dim)
+ return dot(da, weights, dim=dim)
def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
"""Calculate the sum of weights, accounting for missing values"""
diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py
index e205c4d4efe..b5c320e0b96 100644
--- a/xarray/namedarray/_array_api.py
+++ b/xarray/namedarray/_array_api.py
@@ -7,7 +7,11 @@
import numpy as np
from xarray.namedarray._typing import (
+ Default,
_arrayapi,
+ _Axis,
+ _default,
+ _Dim,
_DType,
_ScalarType,
_ShapeType,
@@ -144,3 +148,51 @@ def real(
xp = _get_data_namespace(x)
out = x._new(data=xp.real(x._data))
return out
+
+
+# %% Manipulation functions
+def expand_dims(
+ x: NamedArray[Any, _DType],
+ /,
+ *,
+ dim: _Dim | Default = _default,
+ axis: _Axis = 0,
+) -> NamedArray[Any, _DType]:
+ """
+ Expands the shape of an array by inserting a new dimension of size one at the
+ position specified by dims.
+
+ Parameters
+ ----------
+ x :
+ Array to expand.
+ dim :
+ Dimension name. New dimension will be stored in the axis position.
+ axis :
+ (Not recommended) Axis position (zero-based). Default is 0.
+
+ Returns
+ -------
+ out :
+ An expanded output array having the same data type as x.
+
+ Examples
+ --------
+ >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]]))
+ >>> expand_dims(x)
+ <xarray.NamedArray (dim_2: 1, x: 2, y: 2)>
+ Array([[[1., 2.],
+ [3., 4.]]], dtype=float64)
+ >>> expand_dims(x, dim="z")
+ <xarray.NamedArray (z: 1, x: 2, y: 2)>
+ Array([[[1., 2.],
+ [3., 4.]]], dtype=float64)
+ """
+ xp = _get_data_namespace(x)
+ dims = x.dims
+ if dim is _default:
+ dim = f"dim_{len(dims)}"
+ d = list(dims)
+ d.insert(axis, dim)
+ out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis))
+ return out
diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py
index 0b972e19539..37832daca58 100644
--- a/xarray/namedarray/_typing.py
+++ b/xarray/namedarray/_typing.py
@@ -1,10 +1,12 @@
from __future__ import annotations
from collections.abc import Hashable, Iterable, Mapping, Sequence
+from enum import Enum
from types import ModuleType
from typing import (
Any,
Callable,
+ Final,
Protocol,
SupportsIndex,
TypeVar,
@@ -15,11 +17,19 @@
import numpy as np
+
+# Singleton type, as per https://github.com/python/typing/pull/240
+class Default(Enum):
+ token: Final = 0
+
+
+_default = Default.token
+
# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
-
+_dtype = np.dtype
_DType = TypeVar("_DType", bound=np.dtype[Any])
_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any])
# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
@@ -49,15 +59,26 @@ def dtype(self) -> _DType_co:
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True)
+_Axis = int
+_Axes = tuple[_Axis, ...]
+_AxisLike = Union[_Axis, _Axes]
+
_Chunks = tuple[_Shape, ...]
_Dim = Hashable
_Dims = tuple[_Dim, ...]
_DimsLike = Union[str, Iterable[_Dim]]
-_AttrsLike = Union[Mapping[Any, Any], None]
-_dtype = np.dtype
+# https://data-apis.org/array-api/latest/API_specification/indexing.html
+# TODO: np.array_api was bugged and didn't allow (None,), but should!
+# https://github.com/numpy/numpy/pull/25022
+# https://github.com/data-apis/array-api/pull/674
+_IndexKey = Union[int, slice, "ellipsis"]
+_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...]
+_IndexKeyLike = Union[_IndexKey, _IndexKeys]
+
+_AttrsLike = Union[Mapping[Any, Any], None]
class _SupportsReal(Protocol[_T_co]):
@@ -99,6 +120,25 @@ class _arrayfunction(
Corresponds to np.ndarray.
"""
+ @overload
+ def __getitem__(
+ self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], /
+ ) -> _arrayfunction[Any, _DType_co]:
+ ...
+
+ @overload
+ def __getitem__(self, key: _IndexKeyLike, /) -> Any:
+ ...
+
+ def __getitem__(
+ self,
+ key: _IndexKeyLike
+ | _arrayfunction[Any, Any]
+ | tuple[_arrayfunction[Any, Any], ...],
+ /,
+ ) -> _arrayfunction[Any, _DType_co] | Any:
+ ...
+
@overload
def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]:
...
@@ -151,6 +191,14 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType
Corresponds to np.ndarray.
"""
+ def __getitem__(
+ self,
+ key: _IndexKeyLike
+ | Any, # TODO: Any should be _arrayapi[Any, _dtype[np.integer]]
+ /,
+ ) -> _arrayapi[Any, Any]:
+ ...
+
def __array_namespace__(self) -> ModuleType:
...
diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py
index d3fcffcfd9e..b9ad27b6679 100644
--- a/xarray/namedarray/core.py
+++ b/xarray/namedarray/core.py
@@ -25,6 +25,7 @@
_arrayapi,
_arrayfunction_or_api,
_chunkedarray,
+ _default,
_dtype,
_DType_co,
_ScalarType_co,
@@ -33,13 +34,14 @@
_SupportsImag,
_SupportsReal,
)
-from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array
+from xarray.namedarray.utils import is_duck_dask_array, to_0d_object_array
if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray
from xarray.core.types import Dims
from xarray.namedarray._typing import (
+ Default,
_AttrsLike,
_Chunks,
_Dim,
@@ -52,7 +54,6 @@
_ShapeType,
duckarray,
)
- from xarray.namedarray.utils import Default
try:
from dask.typing import (
@@ -481,6 +482,15 @@ def _parse_dimensions(self, dims: _DimsLike) -> _Dims:
f"dimensions {dims} must have the same length as the "
f"number of data dimensions, ndim={self.ndim}"
)
+ if len(set(dims)) < len(dims):
+ repeated_dims = set([d for d in dims if dims.count(d) > 1])
+ warnings.warn(
+ f"Duplicate dimension names present: dimensions {repeated_dims} appear more than once in dims={dims}. "
+ "We do not yet support duplicate dimension names, but we do allow initial construction of the object. "
+ "We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. "
+ "To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.",
+ UserWarning,
+ )
return dims
@property
@@ -651,6 +661,7 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, .
return self._get_axis_num(dim)
def _get_axis_num(self: Any, dim: Hashable) -> int:
+ _raise_if_any_duplicate_dimensions(self.dims)
try:
return self.dims.index(dim) # type: ignore[no-any-return]
except ValueError:
@@ -846,3 +857,13 @@ def _to_dense(self) -> NamedArray[Any, _DType_co]:
_NamedArray = NamedArray[Any, np.dtype[_ScalarType_co]]
+
+
+def _raise_if_any_duplicate_dimensions(
+ dims: _Dims, err_context: str = "This function"
+) -> None:
+ if len(set(dims)) < len(dims):
+ repeated_dims = set([d for d in dims if dims.count(d) > 1])
+ raise ValueError(
+ f"{err_context} cannot handle duplicate dimensions, but dimensions {repeated_dims} appear more than once on this object's dims: {dims}"
+ )
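
A sketch of the behaviour change: construction with repeated dimension names now warns rather than failing outright, while operations that cannot cope raise through _raise_if_any_duplicate_dimensions:

    import numpy as np
    import xarray as xr

    v = xr.Variable(("x", "x"), np.eye(2))  # UserWarning: duplicate dimension names
    try:
        v.get_axis_num("x")                 # ValueError: cannot handle duplicate dimensions
    except ValueError:
        pass
    v.dims = ("x0", "x1")                   # rename as the warning recommends
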
diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py
index 03eb0134231..4bd20931189 100644
--- a/xarray/namedarray/utils.py
+++ b/xarray/namedarray/utils.py
@@ -2,12 +2,7 @@
import sys
from collections.abc import Hashable
-from enum import Enum
-from typing import (
- TYPE_CHECKING,
- Any,
- Final,
-)
+from typing import TYPE_CHECKING, Any
import numpy as np
@@ -31,14 +26,6 @@
DaskCollection: Any = NDArray # type: ignore
-# Singleton type, as per https://github.com/python/typing/pull/240
-class Default(Enum):
- token: Final = 0
-
-
-_default = Default.token
-
-
def module_available(module: str) -> bool:
"""Checks whether a module is installed without importing it.
diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py
index 61f2014fbc3..aebc3c2bac1 100644
--- a/xarray/plot/dataarray_plot.py
+++ b/xarray/plot/dataarray_plot.py
@@ -27,6 +27,7 @@
_rescale_imshow_rgb,
_resolve_intervals_1dplot,
_resolve_intervals_2dplot,
+ _set_concise_date,
_update_axes,
get_axis,
label_from_attrs,
@@ -525,14 +526,8 @@ def line(
assert hueplt is not None
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
- # Rotate dates on xlabels
- # Do this without calling autofmt_xdate so that x-axes ticks
- # on other subplots (if any) are not deleted.
- # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
if np.issubdtype(xplt.dtype, np.datetime64):
- for xlabels in ax.get_xticklabels():
- xlabels.set_rotation(30)
- xlabels.set_horizontalalignment("right")
+ _set_concise_date(ax, axis="x")
_update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)
@@ -949,6 +944,12 @@ def newplotfunc(
if plotfunc.__name__ == "scatter":
size_ = kwargs.pop("_size", markersize)
size_r = _MARKERSIZE_RANGE
+
+ # Remove any nulls, .where(m, drop=True) doesn't work when m is
+ # a dask array, so load the array to memory.
+ # It will have to be loaded to memory at some point anyway:
+ darray = darray.load()
+ darray = darray.where(darray.notnull(), drop=True)
else:
size_ = kwargs.pop("_size", linewidth)
size_r = _LINEWIDTH_RANGE
@@ -1087,14 +1088,12 @@ def _add_labels(
add_labels: bool | Iterable[bool],
darrays: Iterable[DataArray | None],
suffixes: Iterable[str],
- rotate_labels: Iterable[bool],
ax: Axes,
) -> None:
"""Set x, y, z labels."""
add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels
- for axis, add_label, darray, suffix, rotate_label in zip(
- ("x", "y", "z"), add_labels, darrays, suffixes, rotate_labels
- ):
+ axes: tuple[Literal["x", "y", "z"], ...] = ("x", "y", "z")
+ for axis, add_label, darray, suffix in zip(axes, add_labels, darrays, suffixes):
if darray is None:
continue
@@ -1103,14 +1102,8 @@ def _add_labels(
if label is not None:
getattr(ax, f"set_{axis}label")(label)
- if rotate_label and np.issubdtype(darray.dtype, np.datetime64):
- # Rotate dates on xlabels
- # Do this without calling autofmt_xdate so that x-axes ticks
- # on other subplots (if any) are not deleted.
- # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
- for labels in getattr(ax, f"get_{axis}ticklabels")():
- labels.set_rotation(30)
- labels.set_horizontalalignment("right")
+ if np.issubdtype(darray.dtype, np.datetime64):
+ _set_concise_date(ax, axis=axis)
@overload
@@ -1265,7 +1258,7 @@ def scatter(
kwargs.update(s=sizeplt.to_numpy().ravel())
plts_or_none = (xplt, yplt, zplt)
- _add_labels(add_labels, plts_or_none, ("", "", ""), (True, False, False), ax)
+ _add_labels(add_labels, plts_or_none, ("", "", ""), ax)
xplt_np = None if xplt is None else xplt.to_numpy().ravel()
yplt_np = None if yplt is None else yplt.to_numpy().ravel()
@@ -1653,14 +1646,8 @@ def newplotfunc(
ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim
)
- # Rotate dates on xlabels
- # Do this without calling autofmt_xdate so that x-axes ticks
- # on other subplots (if any) are not deleted.
- # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
if np.issubdtype(xplt.dtype, np.datetime64):
- for xlabels in ax.get_xticklabels():
- xlabels.set_rotation(30)
- xlabels.set_horizontalalignment("right")
+ _set_concise_date(ax, "x")
return primitive
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py
index 5694acc06e8..903780b1137 100644
--- a/xarray/plot/utils.py
+++ b/xarray/plot/utils.py
@@ -6,7 +6,7 @@
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from datetime import datetime
from inspect import getfullargspec
-from typing import TYPE_CHECKING, Any, Callable, overload
+from typing import TYPE_CHECKING, Any, Callable, Literal, overload
import numpy as np
import pandas as pd
@@ -1827,3 +1827,27 @@ def _guess_coords_to_plot(
_assert_valid_xy(darray, dim, k)
return coords_to_plot
+
+
+def _set_concise_date(ax: Axes, axis: Literal["x", "y", "z"] = "x") -> None:
+ """
+ Use ConciseDateFormatter which is meant to improve the
+ strings chosen for the ticklabels, and to minimize the
+ strings used in those tick labels as much as possible.
+
+ https://matplotlib.org/stable/gallery/ticks/date_concise_formatter.html
+
+ Parameters
+ ----------
+ ax : Axes
+ Figure axes.
+ axis : Literal["x", "y", "z"], optional
+ Which axis to make concise. The default is "x".
+ """
+ import matplotlib.dates as mdates
+
+ locator = mdates.AutoDateLocator()
+ formatter = mdates.ConciseDateFormatter(locator)
+ _axis = getattr(ax, f"{axis}axis")
+ _axis.set_major_locator(locator)
+ _axis.set_major_formatter(formatter)
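
_set_concise_date is essentially the standard matplotlib ConciseDateFormatter recipe; for reference, the equivalent direct calls look like this:

    import matplotlib.dates as mdates
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    locator = mdates.AutoDateLocator()
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
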
diff --git a/xarray/testing/__init__.py b/xarray/testing/__init__.py
new file mode 100644
index 00000000000..ab2f8ba4357
--- /dev/null
+++ b/xarray/testing/__init__.py
@@ -0,0 +1,23 @@
+from xarray.testing.assertions import ( # noqa: F401
+ _assert_dataarray_invariants,
+ _assert_dataset_invariants,
+ _assert_indexes_invariants_checks,
+ _assert_internal_invariants,
+ _assert_variable_invariants,
+ _data_allclose_or_equiv,
+ assert_allclose,
+ assert_chunks_equal,
+ assert_duckarray_allclose,
+ assert_duckarray_equal,
+ assert_equal,
+ assert_identical,
+)
+
+__all__ = [
+ "assert_allclose",
+ "assert_chunks_equal",
+ "assert_duckarray_equal",
+ "assert_duckarray_allclose",
+ "assert_equal",
+ "assert_identical",
+]
diff --git a/xarray/testing.py b/xarray/testing/assertions.py
similarity index 98%
rename from xarray/testing.py
rename to xarray/testing/assertions.py
index 0837b562668..faa595a64b6 100644
--- a/xarray/testing.py
+++ b/xarray/testing/assertions.py
@@ -14,15 +14,6 @@
from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes
from xarray.core.variable import IndexVariable, Variable
-__all__ = (
- "assert_allclose",
- "assert_chunks_equal",
- "assert_duckarray_equal",
- "assert_duckarray_allclose",
- "assert_equal",
- "assert_identical",
-)
-
def ensure_warnings(func):
# sometimes tests elevate warnings to errors
diff --git a/xarray/testing/strategies.py b/xarray/testing/strategies.py
new file mode 100644
index 00000000000..d08cbc0b584
--- /dev/null
+++ b/xarray/testing/strategies.py
@@ -0,0 +1,447 @@
+from collections.abc import Hashable, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Protocol, Union, overload
+
+try:
+ import hypothesis.strategies as st
+except ImportError as e:
+ raise ImportError(
+ "`xarray.testing.strategies` requires `hypothesis` to be installed."
+ ) from e
+
+import hypothesis.extra.numpy as npst
+import numpy as np
+from hypothesis.errors import InvalidArgument
+
+import xarray as xr
+from xarray.core.types import T_DuckArray
+
+if TYPE_CHECKING:
+ from xarray.core.types import _DTypeLikeNested, _ShapeLike
+
+
+__all__ = [
+ "supported_dtypes",
+ "names",
+ "dimension_names",
+ "dimension_sizes",
+ "attrs",
+ "variables",
+ "unique_subset_of",
+]
+
+
+class ArrayStrategyFn(Protocol[T_DuckArray]):
+ def __call__(
+ self,
+ *,
+ shape: "_ShapeLike",
+ dtype: "_DTypeLikeNested",
+ ) -> st.SearchStrategy[T_DuckArray]:
+ ...
+
+
+def supported_dtypes() -> st.SearchStrategy[np.dtype]:
+ """
+ Generates only those numpy dtypes which xarray can handle.
+
+ Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes.
+ Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows.
+
+ Requires the hypothesis package to be installed.
+
+ See Also
+ --------
+ :ref:`testing.hypothesis`_
+ """
+ # TODO should this be exposed publicly?
+ # We should at least decide what the set of numpy dtypes that xarray officially supports is.
+ return (
+ npst.integer_dtypes()
+ | npst.unsigned_integer_dtypes()
+ | npst.floating_dtypes()
+ | npst.complex_number_dtypes()
+ )
+
+
+# TODO Generalize to all valid unicode characters once formatting bugs in xarray's reprs are fixed + docs can handle it.
+_readable_characters = st.characters(
+ categories=["L", "N"], max_codepoint=0x017F
+) # only use characters within the "Latin Extended-A" subset of unicode
+
+
+def names() -> st.SearchStrategy[str]:
+ """
+ Generates arbitrary string names for dimensions / variables.
+
+ Requires the hypothesis package to be installed.
+
+ See Also
+ --------
+ :ref:`testing.hypothesis`_
+ """
+ return st.text(
+ _readable_characters,
+ min_size=1,
+ max_size=5,
+ )
+
+
+def dimension_names(
+ *,
+ min_dims: int = 0,
+ max_dims: int = 3,
+) -> st.SearchStrategy[list[Hashable]]:
+ """
+ Generates an arbitrary list of valid dimension names.
+
+ Requires the hypothesis package to be installed.
+
+ Parameters
+ ----------
+ min_dims
+ Minimum number of dimensions in generated list.
+ max_dims
+ Maximum number of dimensions in generated list.
+ """
+
+ return st.lists(
+ elements=names(),
+ min_size=min_dims,
+ max_size=max_dims,
+ unique=True,
+ )
+
+
+def dimension_sizes(
+ *,
+ dim_names: st.SearchStrategy[Hashable] = names(),
+ min_dims: int = 0,
+ max_dims: int = 3,
+ min_side: int = 1,
+ max_side: Union[int, None] = None,
+) -> st.SearchStrategy[Mapping[Hashable, int]]:
+ """
+ Generates an arbitrary mapping from dimension names to lengths.
+
+ Requires the hypothesis package to be installed.
+
+ Parameters
+ ----------
+ dim_names: strategy generating strings, optional
+ Strategy for generating dimension names.
+ Defaults to the `names` strategy.
+ min_dims: int, optional
+ Minimum number of dimensions in generated list.
+ Default is 0.
+ max_dims: int, optional
+ Maximum number of dimensions in generated list.
+ Default is 3.
+ min_side: int, optional
+ Minimum size of a dimension.
+ Default is 1.
+ max_side: int, optional
+ Maximum size of a dimension.
+ Default is `min_side` + 3.
+
+ See Also
+ --------
+ :ref:`testing.hypothesis`_
+ """
+
+ if max_side is None:
+ max_side = min_side + 3
+
+ return st.dictionaries(
+ keys=dim_names,
+ values=st.integers(min_value=min_side, max_value=max_side),
+ min_size=min_dims,
+ max_size=max_dims,
+ )
+
+
+_readable_strings = st.text(
+ _readable_characters,
+ max_size=5,
+)
+_attr_keys = _readable_strings
+_small_arrays = npst.arrays(
+ shape=npst.array_shapes(
+ max_side=2,
+ max_dims=2,
+ ),
+ dtype=npst.scalar_dtypes(),
+)
+_attr_values = st.none() | st.booleans() | _readable_strings | _small_arrays
+
+
+def attrs() -> st.SearchStrategy[Mapping[Hashable, Any]]:
+ """
+ Generates arbitrary valid attributes dictionaries for xarray objects.
+
+ The generated dictionaries can potentially be recursive.
+
+ Requires the hypothesis package to be installed.
+
+ See Also
+ --------
+ :ref:`testing.hypothesis`_
+ """
+ return st.recursive(
+ st.dictionaries(_attr_keys, _attr_values),
+ lambda children: st.dictionaries(_attr_keys, children),
+ max_leaves=3,
+ )
+
+
+@st.composite
+def variables(
+ draw: st.DrawFn,
+ *,
+ array_strategy_fn: Union[ArrayStrategyFn, None] = None,
+ dims: Union[
+ st.SearchStrategy[Union[Sequence[Hashable], Mapping[Hashable, int]]],
+ None,
+ ] = None,
+ dtype: st.SearchStrategy[np.dtype] = supported_dtypes(),
+ attrs: st.SearchStrategy[Mapping] = attrs(),
+) -> xr.Variable:
+ """
+ Generates arbitrary xarray.Variable objects.
+
+ Follows the basic signature of the xarray.Variable constructor, but allows passing alternative strategies to
+ generate either numpy-like array data or dimensions. Also allows specifying the shape or dtype of the wrapped array
+ up front.
+
+ Passing nothing will generate a completely arbitrary Variable (containing a numpy array).
+
+ Requires the hypothesis package to be installed.
+
+ Parameters
+ ----------
+ array_strategy_fn: Callable which returns a strategy generating array-likes, optional
+ Callable must only accept shape and dtype kwargs, and must generate results consistent with its input.
+ If not passed the default is to generate a small numpy array with one of the supported_dtypes.
+ dims: Strategy for generating the dimensions, optional
+ Can either be a strategy for generating a sequence of string dimension names,
+ or a strategy for generating a mapping of string dimension names to integer lengths along each dimension.
+ If provided as a mapping the array shape will be passed to array_strategy_fn.
+ Default is to generate arbitrary dimension names for each axis in data.
+ dtype: Strategy which generates np.dtype objects, optional
+ Will be passed in to array_strategy_fn.
+ Default is to generate any scalar dtype using supported_dtypes.
+ Be aware that this default set of dtypes includes some not strictly allowed by the array API standard.
+ attrs: Strategy which generates dicts, optional
+ Default is to generate a nested attributes dictionary containing arbitrary strings, booleans, integers, Nones,
+ and numpy arrays.
+
+ Returns
+ -------
+ variable_strategy
+ Strategy for generating xarray.Variable objects.
+
+ Raises
+ ------
+ ValueError
+ If a custom array_strategy_fn returns a strategy which generates an example array inconsistent with the shape
+ & dtype input passed to it.
+
+ Examples
+ --------
+ Generate completely arbitrary Variable objects backed by a numpy array:
+
+ >>> variables().example() # doctest: +SKIP
+
+ array([43506, -16, -151], dtype=int32)
+ >>> variables().example() # doctest: +SKIP
+
+ array([[[-10000000., -10000000.],
+ [-10000000., -10000000.]],
+ [[-10000000., -10000000.],
+ [ 0., -10000000.]],
+ [[ 0., -10000000.],
+ [-10000000., inf]],
+ [[ -0., -10000000.],
+ [-10000000., -0.]]], dtype=float32)
+ Attributes:
+ śřĴ: {'ĉ': {'iĥf': array([-30117, -1740], dtype=int16)}}
+
+ Generate only Variable objects with certain dimension names:
+
+ >>> variables(dims=st.just(["a", "b"])).example() # doctest: +SKIP
+
+ array([[ 248, 4294967295, 4294967295],
+ [2412855555, 3514117556, 4294967295],
+ [ 111, 4294967295, 4294967295],
+ [4294967295, 1084434988, 51688],
+ [ 47714, 252, 11207]], dtype=uint32)
+
+ Generate only Variable objects with certain dimension names and lengths:
+
+ >>> variables(dims=st.just({"a": 2, "b": 1})).example() # doctest: +SKIP
+
+ array([[-1.00000000e+007+3.40282347e+038j],
+ [-2.75034266e-225+2.22507386e-311j]])
+
+ See Also
+ --------
+ :ref:`testing.hypothesis`_
+ """
+
+ if not isinstance(dims, st.SearchStrategy) and dims is not None:
+ raise InvalidArgument(
+ f"dims must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(dims)}. "
+ "To specify fixed contents, use hypothesis.strategies.just()."
+ )
+ if not isinstance(dtype, st.SearchStrategy) and dtype is not None:
+ raise InvalidArgument(
+ f"dtype must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(dtype)}. "
+ "To specify fixed contents, use hypothesis.strategies.just()."
+ )
+ if not isinstance(attrs, st.SearchStrategy) and attrs is not None:
+ raise InvalidArgument(
+ f"attrs must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(attrs)}. "
+ "To specify fixed contents, use hypothesis.strategies.just()."
+ )
+
+ _array_strategy_fn: ArrayStrategyFn
+ if array_strategy_fn is None:
+ # For some reason if I move the default value to the function signature definition mypy incorrectly says the ignore is no longer necessary, making it impossible to satisfy mypy
+ _array_strategy_fn = npst.arrays # type: ignore[assignment] # npst.arrays has extra kwargs that we aren't using later
+ elif not callable(array_strategy_fn):
+ raise InvalidArgument(
+ "array_strategy_fn must be a Callable that accepts the kwargs dtype and shape and returns a hypothesis "
+ "strategy which generates corresponding array-like objects."
+ )
+ else:
+ _array_strategy_fn = (
+ array_strategy_fn # satisfy mypy that this new variable cannot be None
+ )
+
+ _dtype = draw(dtype)
+
+ if dims is not None:
+ # generate dims first then draw data to match
+ _dims = draw(dims)
+ if isinstance(_dims, Sequence):
+ dim_names = list(_dims)
+ valid_shapes = npst.array_shapes(min_dims=len(_dims), max_dims=len(_dims))
+ _shape = draw(valid_shapes)
+ array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype)
+ elif isinstance(_dims, (Mapping, dict)):
+ # should be a mapping of form {dim_names: lengths}
+ dim_names, _shape = list(_dims.keys()), tuple(_dims.values())
+ array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype)
+ else:
+ raise InvalidArgument(
+ f"Invalid type returned by dims strategy - drew an object of type {type(dims)}"
+ )
+ else:
+ # nothing provided, so generate everything consistently
+ # We still generate the shape first here just so that we always pass shape to array_strategy_fn
+ _shape = draw(npst.array_shapes())
+ array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype)
+ dim_names = draw(dimension_names(min_dims=len(_shape), max_dims=len(_shape)))
+
+ _data = draw(array_strategy)
+
+ if _data.shape != _shape:
+ raise ValueError(
+ "array_strategy_fn returned an array object with a different shape than it was passed."
+ f"Passed {_shape}, but returned {_data.shape}."
+ "Please either specify a consistent shape via the dims kwarg or ensure the array_strategy_fn callable "
+ "obeys the shape argument passed to it."
+ )
+ if _data.dtype != _dtype:
+ raise ValueError(
+ "array_strategy_fn returned an array object with a different dtype than it was passed."
+ f"Passed {_dtype}, but returned {_data.dtype}"
+ "Please either specify a consistent dtype via the dtype kwarg or ensure the array_strategy_fn callable "
+ "obeys the dtype argument passed to it."
+ )
+
+ return xr.Variable(dims=dim_names, data=_data, attrs=draw(attrs))
+
+
+@overload
+def unique_subset_of(
+ objs: Sequence[Hashable],
+ *,
+ min_size: int = 0,
+ max_size: Union[int, None] = None,
+) -> st.SearchStrategy[Sequence[Hashable]]:
+ ...
+
+
+@overload
+def unique_subset_of(
+ objs: Mapping[Hashable, Any],
+ *,
+ min_size: int = 0,
+ max_size: Union[int, None] = None,
+) -> st.SearchStrategy[Mapping[Hashable, Any]]:
+ ...
+
+
+@st.composite
+def unique_subset_of(
+ draw: st.DrawFn,
+ objs: Union[Sequence[Hashable], Mapping[Hashable, Any]],
+ *,
+ min_size: int = 0,
+ max_size: Union[int, None] = None,
+) -> Union[Sequence[Hashable], Mapping[Hashable, Any]]:
+ """
+ Return a strategy which generates a unique subset of the given objects.
+
+ Each entry in the output subset will be unique (if input was a sequence) or have a unique key (if it was a mapping).
+
+ Requires the hypothesis package to be installed.
+
+ Parameters
+ ----------
+ objs: Union[Sequence[Hashable], Mapping[Hashable, Any]]
+ Objects from which to sample to produce the subset.
+ min_size: int, optional
+ Minimum size of the returned subset. Default is 0.
+ max_size: int, optional
+ Maximum size of the returned subset. Default is the full length of the input.
+ If set to 0 the result will be empty.
+
+ Returns
+ -------
+ unique_subset_strategy
+ Strategy generating subset of the input.
+
+ Examples
+ --------
+ >>> unique_subset_of({"x": 2, "y": 3}).example() # doctest: +SKIP
+ {'y': 3}
+ >>> unique_subset_of(["x", "y"]).example() # doctest: +SKIP
+ ['x']
+
+ See Also
+ --------
+ :ref:`testing.hypothesis`_
+ """
+ if not isinstance(objs, Iterable):
+ raise TypeError(
+ f"Object to sample from must be an Iterable or a Mapping, but received type {type(objs)}"
+ )
+
+ if len(objs) == 0:
+ raise ValueError("Can't sample from a length-zero object.")
+
+ keys = list(objs.keys()) if isinstance(objs, Mapping) else objs
+
+ subset_keys = draw(
+ st.lists(
+ st.sampled_from(keys),
+ unique=True,
+ min_size=min_size,
+ max_size=max_size,
+ )
+ )
+
+ return (
+ {k: objs[k] for k in subset_keys} if isinstance(objs, Mapping) else subset_keys
+ )
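
Typical use of the new strategies in a property-based test (a sketch; requires hypothesis, and the test name and asserted property here are illustrative only):

    import hypothesis.strategies as st
    from hypothesis import given

    import xarray.testing.strategies as xrst

    @given(var=xrst.variables(dims=st.just({"a": 2, "b": 3})))
    def test_reduction_drops_dim(var):
        # every generated Variable has dims ("a", "b") with the given sizes
        assert "a" not in var.mean(dim="a").dims
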
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index fec695f83d7..b3a31b28016 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -53,7 +53,8 @@ def _importorskip(
mod = importlib.import_module(modname)
has = True
if minversion is not None:
- if Version(mod.__version__) < Version(minversion):
+ v = getattr(mod, "__version__", "999")
+ if Version(v) < Version(minversion):
raise ImportError("Minimum version not satisfied")
except ImportError:
has = False
@@ -63,7 +64,14 @@ def _importorskip(
has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
has_scipy, requires_scipy = _importorskip("scipy")
-has_pydap, requires_pydap = _importorskip("pydap.client")
+with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore",
+ message="'cgi' is deprecated and slated for removal in Python 3.13",
+ category=DeprecationWarning,
+ )
+
+ has_pydap, requires_pydap = _importorskip("pydap.client")
has_netCDF4, requires_netCDF4 = _importorskip("netCDF4")
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
has_pynio, requires_pynio = _importorskip("Nio")
@@ -89,11 +97,16 @@ def _importorskip(
requires_scipy_or_netCDF4 = pytest.mark.skipif(
not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
)
+has_numbagg_or_bottleneck = has_numbagg or has_bottleneck
+requires_numbagg_or_bottleneck = pytest.mark.skipif(
+ not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck"
+)
# _importorskip does not work for development versions
has_pandas_version_two = Version(pd.__version__).major >= 2
requires_pandas_version_two = pytest.mark.skipif(
not has_pandas_version_two, reason="requires pandas 2.0.0"
)
+has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0")
has_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")
requires_h5netcdf_ros3 = pytest.mark.skipif(
not has_h5netcdf_ros3[0], reason="requires h5netcdf 1.3.0"
@@ -210,11 +223,18 @@ def source_ndarray(array):
return base
+def format_record(record) -> str:
+ """Format warning record like `FutureWarning('Function will be deprecated...')`"""
+ return f"{str(record.category)[8:-2]}('{record.message}'))"
+
+
@contextmanager
def assert_no_warnings():
with warnings.catch_warnings(record=True) as record:
yield record
- assert len(record) == 0, "got unexpected warning(s)"
+ assert (
+ len(record) == 0
+ ), f"Got {len(record)} unexpected warning(s): {[format_record(r) for r in record]}"
# Internal versions of xarray's test functions that validate additional
diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py
index 6a8cf008f9f..f153c2f4dc0 100644
--- a/xarray/tests/conftest.py
+++ b/xarray/tests/conftest.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import numpy as np
import pandas as pd
import pytest
@@ -77,3 +79,44 @@ def da(request, backend):
return da
else:
raise ValueError
+
+
+@pytest.fixture(params=[Dataset, DataArray])
+def type(request):
+ return request.param
+
+
+@pytest.fixture(params=[1])
+def d(request, backend, type) -> DataArray | Dataset:
+ """
+ For tests which can test either a DataArray or a Dataset.
+ """
+ result: DataArray | Dataset
+ if request.param == 1:
+ ds = Dataset(
+ dict(
+ a=(["x", "z"], np.arange(24).reshape(2, 12)),
+ b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)),
+ ),
+ dict(
+ x=("x", np.linspace(0, 1.0, 2)),
+ y=range(3),
+ z=("z", pd.date_range("2000-01-01", periods=12)),
+ w=("x", ["a", "b"]),
+ ),
+ )
+ if type == DataArray:
+ result = ds["a"].assign_coords(w=ds.coords["w"])
+ elif type == Dataset:
+ result = ds
+ else:
+ raise ValueError
+ else:
+ raise ValueError
+
+ if backend == "dask":
+ return result.chunk()
+ elif backend == "numpy":
+ return result
+ else:
+ raise ValueError
diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_assertions.py
similarity index 100%
rename from xarray/tests/test_testing.py
rename to xarray/tests/test_assertions.py
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 21ac5bc8b3d..2f454292a28 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -865,12 +865,13 @@ def test_roundtrip_empty_vlen_string_array(self) -> None:
assert check_vlen_dtype(original["a"].dtype) == str
with self.roundtrip(original) as actual:
assert_identical(original, actual)
- assert object == actual["a"].dtype
- assert actual["a"].dtype == original["a"].dtype
- # only check metadata for capable backends
- # eg. NETCDF3 based backends do not roundtrip metadata
- if actual["a"].dtype.metadata is not None:
- assert check_vlen_dtype(actual["a"].dtype) == str
+ if np.issubdtype(actual["a"].dtype, object):
+ # only check metadata for capable backends
+ # eg. NETCDF3 based backends do not roundtrip metadata
+ if actual["a"].dtype.metadata is not None:
+ assert check_vlen_dtype(actual["a"].dtype) == str
+ else:
+ assert actual["a"].dtype == np.dtype(" None:
with self.open(tmp_file, group="data/2") as actual2:
assert_identical(data2, actual2)
- def test_encoding_kwarg_vlen_string(self) -> None:
- for input_strings in [[b"foo", b"bar", b"baz"], ["foo", "bar", "baz"]]:
- original = Dataset({"x": input_strings})
- expected = Dataset({"x": ["foo", "bar", "baz"]})
- kwargs = dict(encoding={"x": {"dtype": str}})
- with self.roundtrip(original, save_kwargs=kwargs) as actual:
- assert actual["x"].encoding["dtype"] is str
- assert_identical(actual, expected)
-
- def test_roundtrip_string_with_fill_value_vlen(self) -> None:
+ @pytest.mark.parametrize(
+ "input_strings, is_bytes",
+ [
+ ([b"foo", b"bar", b"baz"], True),
+ (["foo", "bar", "baz"], False),
+ (["foó", "bár", "baź"], False),
+ ],
+ )
+ def test_encoding_kwarg_vlen_string(
+ self, input_strings: list[str], is_bytes: bool
+ ) -> None:
+ original = Dataset({"x": input_strings})
+
+ expected_string = ["foo", "bar", "baz"] if is_bytes else input_strings
+ expected = Dataset({"x": expected_string})
+ kwargs = dict(encoding={"x": {"dtype": str}})
+ with self.roundtrip(original, save_kwargs=kwargs) as actual:
+ assert actual["x"].encoding["dtype"] == " None:
values = np.array(["ab", "cdef", np.nan], dtype=object)
expected = Dataset({"x": ("t", values)})
- # netCDF4-based backends don't support an explicit fillvalue
- # for variable length strings yet.
- # https://github.com/Unidata/netcdf4-python/issues/730
- # https://github.com/h5netcdf/h5netcdf/issues/37
- original = Dataset({"x": ("t", values, {}, {"_FillValue": "XXX"})})
- with pytest.raises(NotImplementedError):
- with self.roundtrip(original) as actual:
- assert_identical(expected, actual)
+ original = Dataset({"x": ("t", values, {}, {"_FillValue": fill_value})})
+ with self.roundtrip(original) as actual:
+ assert_identical(expected, actual)
original = Dataset({"x": ("t", values, {}, {"_FillValue": ""})})
- with pytest.raises(NotImplementedError):
- with self.roundtrip(original) as actual:
- assert_identical(expected, actual)
+ with self.roundtrip(original) as actual:
+ assert_identical(expected, actual)
def test_roundtrip_character_array(self) -> None:
with create_tmp_file() as tmp_file:
@@ -2383,6 +2391,29 @@ def test_append_with_new_variable(self) -> None:
xr.open_dataset(store_target, engine="zarr", **self.version_kwargs),
)
+ def test_append_with_append_dim_no_overwrite(self) -> None:
+ ds, ds_to_append, _ = create_append_test_data()
+ with self.create_zarr_target() as store_target:
+ ds.to_zarr(store_target, mode="w", **self.version_kwargs)
+ original = xr.concat([ds, ds_to_append], dim="time")
+ original2 = xr.concat([original, ds_to_append], dim="time")
+
+ # overwrite a coordinate;
+ # for mode='a-', this will not get written to the store
+ # because it does not have the append_dim as a dim
+ ds_to_append.lon.data[:] = -999
+ ds_to_append.to_zarr(
+ store_target, mode="a-", append_dim="time", **self.version_kwargs
+ )
+ actual = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs)
+ assert_identical(original, actual)
+
+ # by default, mode="a" will overwrite all coordinates.
+ ds_to_append.to_zarr(store_target, append_dim="time", **self.version_kwargs)
+ actual = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs)
+ original2.lon.data[:] = -999
+ assert_identical(original2, actual)
+
@requires_dask
def test_to_zarr_compute_false_roundtrip(self) -> None:
from dask.delayed import Delayed
@@ -2579,7 +2610,7 @@ def setup_and_verify_store(expected=data):
with pytest.raises(
ValueError,
match=re.escape(
- "cannot set region unless mode='a', mode='r+' or mode=None"
+ "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None"
),
):
data.to_zarr(
@@ -2829,6 +2860,43 @@ def test_write_empty(
ls = listdir(os.path.join(store, "test"))
assert set(expected) == set([file for file in ls if file[0] != "."])
+ def test_avoid_excess_metadata_calls(self) -> None:
+ """Test that chunk requests do not trigger redundant metadata requests.
+
+ This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls
+ to retrieve chunk data after initialization do not trigger additional
+ metadata requests.
+
+ https://github.com/pydata/xarray/issues/8290
+ """
+
+ import zarr
+
+ ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))})
+
+ # The call to retrieve metadata performs a group lookup. We patch Group.__getitem__
+ # so that we can inspect calls to this method - specifically count of calls.
+ # Use of side_effect means that calls are passed through to the original method
+ # rather than a mocked method.
+ Group = zarr.hierarchy.Group
+ with (
+ self.create_zarr_target() as store,
+ patch.object(
+ Group, "__getitem__", side_effect=Group.__getitem__, autospec=True
+ ) as mock,
+ ):
+ ds.to_zarr(store, mode="w")
+
+ # We expect this to request array metadata information, so call_count should be == 1.
+ xrds = xr.open_zarr(store)
+ call_count = mock.call_count
+ assert call_count == 1
+
+ # compute() requests array data, which should not trigger additional metadata requests
+ # we assert that the number of calls has not increased after fetchhing the array
+ xrds.test.compute(scheduler="sync")
+ assert mock.call_count == call_count
+
class ZarrBaseV3(ZarrBase):
zarr_version = 3
@@ -2869,47 +2937,6 @@ def create_zarr_target(self):
yield tmp
-@requires_zarr
-class TestZarrArrayWrapperCalls(TestZarrKVStoreV3):
- def test_avoid_excess_metadata_calls(self) -> None:
- """Test that chunk requests do not trigger redundant metadata requests.
-
- This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls
- to retrieve chunk data after initialization do not trigger additional
- metadata requests.
-
- https://github.com/pydata/xarray/issues/8290
- """
-
- import zarr
-
- ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))})
-
- # The call to retrieve metadata performs a group lookup. We patch Group.__getitem__
- # so that we can inspect calls to this method - specifically count of calls.
- # Use of side_effect means that calls are passed through to the original method
- # rather than a mocked method.
- Group = zarr.hierarchy.Group
- with (
- self.create_zarr_target() as store,
- patch.object(
- Group, "__getitem__", side_effect=Group.__getitem__, autospec=True
- ) as mock,
- ):
- ds.to_zarr(store, mode="w")
-
- # We expect this to request array metadata information, so call_count should be >= 1,
- # At time of writing, 2 calls are made
- xrds = xr.open_zarr(store)
- call_count = mock.call_count
- assert call_count > 0
-
- # compute() requests array data, which should not trigger additional metadata requests
- # we assert that the number of calls has not increased after fetchhing the array
- xrds.test.compute(scheduler="sync")
- assert mock.call_count == call_count
-
-
@requires_zarr
@requires_fsspec
def test_zarr_storage_options() -> None:
@@ -3457,6 +3484,7 @@ class TestH5NetCDFDataRos3Driver(TestCommon):
"https://www.unidata.ucar.edu/software/netcdf/examples/OMI-Aura_L2-example.nc"
)
+ @pytest.mark.filterwarnings("ignore:Duplicate dimension names")
def test_get_variable_list(self) -> None:
with open_dataset(
self.test_remote_dataset,
@@ -3465,6 +3493,7 @@ def test_get_variable_list(self) -> None:
) as actual:
assert "Temperature" in list(actual)
+ @pytest.mark.filterwarnings("ignore:Duplicate dimension names")
def test_get_variable_list_empty_driver_kwds(self) -> None:
driver_kwds = {
"secret_id": b"",
@@ -5242,6 +5271,16 @@ def test_pickle_open_mfdataset_dataset():
assert_identical(ds, pickle.loads(pickle.dumps(ds)))
+@requires_zarr
+def test_zarr_closing_internal_zip_store():
+ store_name = "tmp.zarr.zip"
+ original_da = DataArray(np.arange(12).reshape((3, 4)))
+ original_da.to_zarr(store_name, mode="w")
+
+ with open_dataarray(store_name, engine="zarr") as loaded_da:
+ assert_identical(original_da, loaded_da)
+
+
@requires_zarr
class TestZarrRegionAuto:
def test_zarr_region_auto_all(self, tmp_path):
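
`test_zarr_closing_internal_zip_store` covers the case where the target path ends in `.zip`, which zarr opens as a zip-backed store; that store has to be closed after writing or a later read of the archive fails. A sketch of the round trip it protects, with an illustrative file name (requires zarr):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(12).reshape(3, 4))
da.to_zarr("demo.zarr.zip", mode="w")  # zarr uses a zip-backed store for *.zip paths

# This only succeeds if the internally created store was closed after the write above.
with xr.open_dataarray("demo.zarr.zip", engine="zarr") as loaded:
    xr.testing.assert_identical(da, loaded)
```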
@@ -5416,7 +5455,7 @@ def test_zarr_region_append(self, tmp_path):
@requires_zarr
-def test_zarr_region_transpose(tmp_path):
+def test_zarr_region(tmp_path):
x = np.arange(0, 50, 10)
y = np.arange(0, 20, 2)
data = np.ones((5, 10))
@@ -5431,10 +5470,14 @@ def test_zarr_region_transpose(tmp_path):
)
ds.to_zarr(tmp_path / "test.zarr")
- ds_region = 1 + ds.isel(x=[0], y=[0]).transpose()
+ ds_transposed = ds.transpose("y", "x")
+
+ ds_region = 1 + ds_transposed.isel(x=[0], y=[0])
ds_region.to_zarr(
tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)}
)
+ # Write without region
+ ds_transposed.to_zarr(tmp_path / "test.zarr", mode="r+")
@requires_dask
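
The renamed `test_zarr_region` now checks both a `region=` write from a dataset whose dimensions are transposed relative to the on-disk order and a follow-up full overwrite with `mode="r+"`, which only replaces values of existing arrays. A sketch of the two write paths, mirroring the test against an illustrative local store:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"v": (("x", "y"), np.ones((5, 10)))},
    coords={"x": np.arange(0, 50, 10), "y": np.arange(0, 20, 2)},
)
ds.to_zarr("region-demo.zarr", mode="w")

# Region write: the data may be transposed relative to the on-disk dimension order.
patch = (1 + ds.isel(x=[0], y=[0])).transpose("y", "x")
patch.to_zarr("region-demo.zarr", region={"x": slice(0, 1), "y": slice(0, 1)})

# mode="r+": overwrite values of existing arrays only (no metadata changes allowed).
ds.transpose("y", "x").to_zarr("region-demo.zarr", mode="r+")
```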
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index 425673dc40f..68c20c4f51b 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1378,7 +1378,7 @@ def func(da):
expected = extract(ds)
actual = extract(ds.chunk())
- assert actual.dims == {"lon_new": 3, "lat_new": 6}
+ assert actual.sizes == {"lon_new": 3, "lat_new": 6}
assert_identical(expected.chunk(), actual)
@@ -1775,6 +1775,97 @@ def test_complex_cov() -> None:
assert abs(actual.item()) == 2
+@pytest.mark.parametrize("weighted", [True, False])
+def test_bilinear_cov_corr(weighted: bool) -> None:
+ # Test the bilinear properties of covariance and correlation
+ da = xr.DataArray(
+ np.random.random((3, 21, 4)),
+ coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
+ dims=("a", "time", "x"),
+ )
+ db = xr.DataArray(
+ np.random.random((3, 21, 4)),
+ coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
+ dims=("a", "time", "x"),
+ )
+ dc = xr.DataArray(
+ np.random.random((3, 21, 4)),
+ coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
+ dims=("a", "time", "x"),
+ )
+ if weighted:
+ weights = xr.DataArray(
+ np.abs(np.random.random(4)),
+ dims=("x"),
+ )
+ else:
+ weights = None
+ k = np.random.random(1)[0]
+
+ # Test covariance properties
+ assert_allclose(
+ xr.cov(da + k, db, weights=weights), xr.cov(da, db, weights=weights)
+ )
+ assert_allclose(
+ xr.cov(da, db + k, weights=weights), xr.cov(da, db, weights=weights)
+ )
+ assert_allclose(
+ xr.cov(da + dc, db, weights=weights),
+ xr.cov(da, db, weights=weights) + xr.cov(dc, db, weights=weights),
+ )
+ assert_allclose(
+ xr.cov(da, db + dc, weights=weights),
+ xr.cov(da, db, weights=weights) + xr.cov(da, dc, weights=weights),
+ )
+ assert_allclose(
+ xr.cov(k * da, db, weights=weights), k * xr.cov(da, db, weights=weights)
+ )
+ assert_allclose(
+ xr.cov(da, k * db, weights=weights), k * xr.cov(da, db, weights=weights)
+ )
+
+ # Test correlation properties
+ assert_allclose(
+ xr.corr(da + k, db, weights=weights), xr.corr(da, db, weights=weights)
+ )
+ assert_allclose(
+ xr.corr(da, db + k, weights=weights), xr.corr(da, db, weights=weights)
+ )
+ assert_allclose(
+ xr.corr(k * da, db, weights=weights), xr.corr(da, db, weights=weights)
+ )
+ assert_allclose(
+ xr.corr(da, k * db, weights=weights), xr.corr(da, db, weights=weights)
+ )
+
+
+def test_equally_weighted_cov_corr() -> None:
+ # Test that equal weights for all values produces same results as weights=None
+ da = xr.DataArray(
+ np.random.random((3, 21, 4)),
+ coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
+ dims=("a", "time", "x"),
+ )
+ db = xr.DataArray(
+ np.random.random((3, 21, 4)),
+ coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
+ dims=("a", "time", "x"),
+ )
+ #
+ assert_allclose(
+ xr.cov(da, db, weights=None), xr.cov(da, db, weights=xr.DataArray(1))
+ )
+ assert_allclose(
+ xr.cov(da, db, weights=None), xr.cov(da, db, weights=xr.DataArray(2))
+ )
+ assert_allclose(
+ xr.corr(da, db, weights=None), xr.corr(da, db, weights=xr.DataArray(1))
+ )
+ assert_allclose(
+ xr.corr(da, db, weights=None), xr.corr(da, db, weights=xr.DataArray(2))
+ )
+
+
@requires_dask
def test_vectorize_dask_new_output_dims() -> None:
# regression test for GH3574
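
For reference, the properties exercised by `test_bilinear_cov_corr` are the standard identities below; they hold equally for the weighted variants selected via `weights=`, and the test draws `k` from (0, 1), which satisfies the positivity needed for the correlation scaling rule.

```latex
\operatorname{cov}(X + k,\ Y) = \operatorname{cov}(X, Y), \qquad
\operatorname{cov}(X + Z,\ Y) = \operatorname{cov}(X, Y) + \operatorname{cov}(Z, Y), \qquad
\operatorname{cov}(kX,\ Y) = k\,\operatorname{cov}(X, Y)

\operatorname{corr}(X + k,\ Y) = \operatorname{corr}(X, Y), \qquad
\operatorname{corr}(kX,\ Y) = \operatorname{corr}(X, Y) \quad (k > 0)
```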
@@ -1936,7 +2027,7 @@ def test_dot(use_dask: bool) -> None:
da_a = da_a.chunk({"a": 3})
da_b = da_b.chunk({"a": 3})
da_c = da_c.chunk({"c": 3})
- actual = xr.dot(da_a, da_b, dims=["a", "b"])
+ actual = xr.dot(da_a, da_b, dim=["a", "b"])
assert actual.dims == ("c",)
assert (actual.data == np.einsum("ij,ijk->k", a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))
@@ -1960,33 +2051,33 @@ def test_dot(use_dask: bool) -> None:
if use_dask:
da_a = da_a.chunk({"a": 3})
da_b = da_b.chunk({"a": 3})
- actual = xr.dot(da_a, da_b, dims=["b"])
+ actual = xr.dot(da_a, da_b, dim=["b"])
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))
- actual = xr.dot(da_a, da_b, dims=["b"])
+ actual = xr.dot(da_a, da_b, dim=["b"])
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
- actual = xr.dot(da_a, da_b, dims="b")
+ actual = xr.dot(da_a, da_b, dim="b")
assert actual.dims == ("a", "c")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
- actual = xr.dot(da_a, da_b, dims="a")
+ actual = xr.dot(da_a, da_b, dim="a")
assert actual.dims == ("b", "c")
assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all()
- actual = xr.dot(da_a, da_b, dims="c")
+ actual = xr.dot(da_a, da_b, dim="c")
assert actual.dims == ("a", "b")
assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all()
- actual = xr.dot(da_a, da_b, da_c, dims=["a", "b"])
+ actual = xr.dot(da_a, da_b, da_c, dim=["a", "b"])
assert actual.dims == ("c", "e")
assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all()
# should work with tuple
- actual = xr.dot(da_a, da_b, dims=("c",))
+ actual = xr.dot(da_a, da_b, dim=("c",))
assert actual.dims == ("a", "b")
assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all()
@@ -1996,47 +2087,47 @@ def test_dot(use_dask: bool) -> None:
assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all()
# 1 array summation
- actual = xr.dot(da_a, dims="a")
+ actual = xr.dot(da_a, dim="a")
assert actual.dims == ("b",)
assert (actual.data == np.einsum("ij->j ", a)).all()
# empty dim
- actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims="a")
+ actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim="a")
assert actual.dims == ("b",)
assert (actual.data == np.zeros(actual.shape)).all()
# Ellipsis (...) sums over all dimensions
- actual = xr.dot(da_a, da_b, dims=...)
+ actual = xr.dot(da_a, da_b, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk->", a, b)).all()
- actual = xr.dot(da_a, da_b, da_c, dims=...)
+ actual = xr.dot(da_a, da_b, da_c, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all()
- actual = xr.dot(da_a, dims=...)
+ actual = xr.dot(da_a, dim=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij-> ", a)).all()
- actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...)
+ actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim=...)
assert actual.dims == ()
assert (actual.data == np.zeros(actual.shape)).all()
# Invalid cases
if not use_dask:
with pytest.raises(TypeError):
- xr.dot(da_a, dims="a", invalid=None)
+ xr.dot(da_a, dim="a", invalid=None)
with pytest.raises(TypeError):
- xr.dot(da_a.to_dataset(name="da"), dims="a")
+ xr.dot(da_a.to_dataset(name="da"), dim="a")
with pytest.raises(TypeError):
- xr.dot(dims="a")
+ xr.dot(dim="a")
# einsum parameters
- actual = xr.dot(da_a, da_b, dims=["b"], order="C")
+ actual = xr.dot(da_a, da_b, dim=["b"], order="C")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
assert actual.values.flags["C_CONTIGUOUS"]
assert not actual.values.flags["F_CONTIGUOUS"]
- actual = xr.dot(da_a, da_b, dims=["b"], order="F")
+ actual = xr.dot(da_a, da_b, dim=["b"], order="F")
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
# dask converts Fortran arrays to C order when merging the final array
if not use_dask:
@@ -2078,7 +2169,7 @@ def test_dot_align_coords(use_dask: bool) -> None:
expected = (da_a * da_b).sum(["a", "b"])
xr.testing.assert_allclose(expected, actual)
- actual = xr.dot(da_a, da_b, dims=...)
+ actual = xr.dot(da_a, da_b, dim=...)
expected = (da_a * da_b).sum()
xr.testing.assert_allclose(expected, actual)
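
The `dims=` to `dim=` replacements above follow the renaming of the reduction keyword of `xr.dot` and `DataArray.dot`, bringing it in line with other reduction methods; the old spelling is presumably still accepted with a deprecation warning during the transition. A one-line usage sketch:

```python
import numpy as np
import xarray as xr

a = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y"))
b = xr.DataArray(np.arange(12).reshape(2, 3, 2), dims=("x", "y", "z"))

# New spelling; previously this was dims=["x", "y"].
result = xr.dot(a, b, dim=["x", "y"])
assert result.dims == ("z",)
```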
diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py
index 92415631748..d1fc085bf0f 100644
--- a/xarray/tests/test_concat.py
+++ b/xarray/tests/test_concat.py
@@ -509,7 +509,7 @@ def test_concat_coords_kwarg(self, data, dim, coords) -> None:
actual = concat(datasets, data[dim], coords=coords)
if coords == "all":
- expected = np.array([data["extra"].values for _ in range(data.dims[dim])])
+ expected = np.array([data["extra"].values for _ in range(data.sizes[dim])])
assert_array_equal(actual["extra"].values, expected)
else:
@@ -1214,7 +1214,7 @@ def test_concat_preserve_coordinate_order() -> None:
# check dimension order
for act, exp in zip(actual.dims, expected.dims):
assert act == exp
- assert actual.dims[act] == expected.dims[exp]
+ assert actual.sizes[act] == expected.sizes[exp]
# check coordinate order
for act, exp in zip(actual.coords, expected.coords):
diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py
index d6d1303a696..be6e949edf8 100644
--- a/xarray/tests/test_conventions.py
+++ b/xarray/tests/test_conventions.py
@@ -495,6 +495,18 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
pass
+@pytest.mark.parametrize(
+ "data",
+ [
+ np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
+ np.array([["x", 1], ["y", 2]], dtype="object"),
+ ],
+)
+def test_infer_dtype_error_on_mixed_types(data):
+ with pytest.raises(ValueError, match="unable to infer dtype on variable"):
+ conventions._infer_dtype(data, "test")
+
+
class TestDecodeCFVariableWithArrayUnits:
def test_decode_cf_variable_with_array_units(self) -> None:
v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})
diff --git a/xarray/tests/test_coordinates.py b/xarray/tests/test_coordinates.py
index ef73371dfe4..68ce55b05da 100644
--- a/xarray/tests/test_coordinates.py
+++ b/xarray/tests/test_coordinates.py
@@ -79,9 +79,10 @@ def test_from_pandas_multiindex(self) -> None:
for name in ("x", "one", "two"):
assert_identical(expected[name], coords.variables[name])
+ @pytest.mark.filterwarnings("ignore:return type")
def test_dims(self) -> None:
coords = Coordinates(coords={"x": [0, 1, 2]})
- assert coords.dims == {"x": 3}
+ assert set(coords.dims) == {"x"}
def test_sizes(self) -> None:
coords = Coordinates(coords={"x": [0, 1, 2]})
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index c2a77c97d85..137d6020829 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1367,6 +1367,25 @@ def test_map_blocks_da_ds_with_template(obj):
assert_identical(actual, template)
+def test_map_blocks_roundtrip_string_index():
+ ds = xr.Dataset(
+ {"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]}
+ ).chunk(label=1)
+    assert ds.label.dtype == np.dtype("<U3")
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
@@ ... @@ def test_drop_vars(self) -> None:
actual = renamed.drop_vars("foo", errors="ignore")
assert_identical(actual, renamed)
+ def test_drop_vars_callable(self) -> None:
+ A = DataArray(
+ np.random.randn(2, 3), dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4, 5]}
+ )
+ expected = A.drop_vars(["x", "y"])
+ actual = A.drop_vars(lambda x: x.indexes)
+ assert_identical(expected, actual)
+
def test_drop_multiindex_level(self) -> None:
# GH6505
expected = self.mda.drop_vars(["x", "level_1", "level_2"])
@@ -3964,13 +3972,13 @@ def test_dot(self) -> None:
assert_equal(expected3, actual3)
# Ellipsis: all dims are shared
- actual4 = da.dot(da, dims=...)
+ actual4 = da.dot(da, dim=...)
expected4 = da.dot(da)
assert_equal(expected4, actual4)
# Ellipsis: not all dims are shared
- actual5 = da.dot(dm3, dims=...)
- expected5 = da.dot(dm3, dims=("j", "x", "y", "z"))
+ actual5 = da.dot(dm3, dim=...)
+ expected5 = da.dot(dm3, dim=("j", "x", "y", "z"))
assert_equal(expected5, actual5)
with pytest.raises(NotImplementedError):
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index ff7703a1cf5..664d108b89c 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -687,6 +687,7 @@ class CustomIndex(Index):
# test coordinate variables copied
assert ds.variables["x"] is not coords.variables["x"]
+ @pytest.mark.filterwarnings("ignore:return type")
def test_properties(self) -> None:
ds = create_test_data()
@@ -694,10 +695,11 @@ def test_properties(self) -> None:
# These exact types aren't public API, but this makes sure we don't
# change them inadvertently:
assert isinstance(ds.dims, utils.Frozen)
+ # TODO change after deprecation cycle in GH #8500 is complete
assert isinstance(ds.dims.mapping, dict)
assert type(ds.dims.mapping) is dict # noqa: E721
- assert ds.dims == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20}
- assert ds.sizes == ds.dims
+ assert ds.dims == ds.sizes
+ assert ds.sizes == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20}
# dtypes
assert isinstance(ds.dtypes, utils.Frozen)
@@ -749,6 +751,27 @@ def test_properties(self) -> None:
== 16
)
+ def test_warn_ds_dims_deprecation(self) -> None:
+ # TODO remove after deprecation cycle in GH #8500 is complete
+ ds = create_test_data()
+
+ with pytest.warns(FutureWarning, match="return type"):
+ ds.dims["dim1"]
+
+ with pytest.warns(FutureWarning, match="return type"):
+ ds.dims.keys()
+
+ with pytest.warns(FutureWarning, match="return type"):
+ ds.dims.values()
+
+ with pytest.warns(FutureWarning, match="return type"):
+ ds.dims.items()
+
+ with assert_no_warnings():
+ len(ds.dims)
+ ds.dims.__iter__()
+ "dim1" in ds.dims
+
def test_asarray(self) -> None:
ds = Dataset({"x": 0})
with pytest.raises(TypeError, match=r"cannot directly convert"):
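
The `ds.dims[...]` to `ds.sizes[...]` substitutions throughout this file follow the deprecation checked by `test_warn_ds_dims_deprecation` (GH #8500): `Dataset.dims` is slated to become a set-like collection of dimension names, so code that needs lengths should read them from `.sizes`. A short migration sketch:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"t": (("x", "y"), np.zeros((2, 3)))})

# Old, mapping-style access; per the deprecation above this now warns on Dataset:
# n = ds.dims["x"]

# Preferred: .sizes is the explicit name -> length mapping on Dataset and DataArray.
n = ds.sizes["x"]
assert n == 2

# Iteration and membership over dimension names keep working on .dims:
assert set(ds.dims) == {"x", "y"}
assert "x" in ds.dims
```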
@@ -804,7 +827,7 @@ def test_modify_inplace(self) -> None:
b = Dataset()
b["x"] = ("x", vec, attributes)
assert_identical(a["x"], b["x"])
- assert a.dims == b.dims
+ assert a.sizes == b.sizes
# this should work
a["x"] = ("x", vec[:5])
a["z"] = ("x", np.arange(5))
@@ -865,7 +888,7 @@ def test_coords_properties(self) -> None:
assert expected == actual
# dims
- assert coords.dims == {"x": 2, "y": 3}
+ assert coords.sizes == {"x": 2, "y": 3}
# dtypes
assert coords.dtypes == {
@@ -1215,9 +1238,9 @@ def test_isel(self) -> None:
assert list(data.dims) == list(ret.dims)
for d in data.dims:
if d in slicers:
- assert ret.dims[d] == np.arange(data.dims[d])[slicers[d]].size
+ assert ret.sizes[d] == np.arange(data.sizes[d])[slicers[d]].size
else:
- assert data.dims[d] == ret.dims[d]
+ assert data.sizes[d] == ret.sizes[d]
# Verify that the data is what we expect
for v in data.variables:
assert data[v].dims == ret[v].dims
@@ -1251,19 +1274,19 @@ def test_isel(self) -> None:
assert_identical(data, data.isel(not_a_dim=slice(0, 2), missing_dims="ignore"))
ret = data.isel(dim1=0)
- assert {"time": 20, "dim2": 9, "dim3": 10} == ret.dims
+ assert {"time": 20, "dim2": 9, "dim3": 10} == ret.sizes
assert set(data.data_vars) == set(ret.data_vars)
assert set(data.coords) == set(ret.coords)
assert set(data.xindexes) == set(ret.xindexes)
ret = data.isel(time=slice(2), dim1=0, dim2=slice(5))
- assert {"time": 2, "dim2": 5, "dim3": 10} == ret.dims
+ assert {"time": 2, "dim2": 5, "dim3": 10} == ret.sizes
assert set(data.data_vars) == set(ret.data_vars)
assert set(data.coords) == set(ret.coords)
assert set(data.xindexes) == set(ret.xindexes)
ret = data.isel(time=0, dim1=0, dim2=slice(5))
- assert {"dim2": 5, "dim3": 10} == ret.dims
+ assert {"dim2": 5, "dim3": 10} == ret.sizes
assert set(data.data_vars) == set(ret.data_vars)
assert set(data.coords) == set(ret.coords)
assert set(data.xindexes) == set(list(ret.xindexes) + ["time"])
@@ -2651,19 +2674,19 @@ def test_drop_variables(self) -> None:
# deprecated approach with `drop` works (straight copy paste from above)
- with pytest.warns(PendingDeprecationWarning):
+ with pytest.warns(DeprecationWarning):
actual = data.drop("not_found_here", errors="ignore")
assert_identical(data, actual)
- with pytest.warns(PendingDeprecationWarning):
+ with pytest.warns(DeprecationWarning):
actual = data.drop(["not_found_here"], errors="ignore")
assert_identical(data, actual)
- with pytest.warns(PendingDeprecationWarning):
+ with pytest.warns(DeprecationWarning):
actual = data.drop(["time", "not_found_here"], errors="ignore")
assert_identical(expected, actual)
- with pytest.warns(PendingDeprecationWarning):
+ with pytest.warns(DeprecationWarning):
actual = data.drop({"time", "not_found_here"}, errors="ignore")
assert_identical(expected, actual)
@@ -2736,9 +2759,9 @@ def test_drop_labels_by_keyword(self) -> None:
ds5 = data.drop_sel(x=["a", "b"], y=range(0, 6, 2))
arr = DataArray(range(3), dims=["c"])
- with pytest.warns(FutureWarning):
+ with pytest.warns(DeprecationWarning):
data.drop(arr.coords)
- with pytest.warns(FutureWarning):
+ with pytest.warns(DeprecationWarning):
data.drop(arr.xindexes)
assert_array_equal(ds1.coords["x"], ["b"])
@@ -4697,6 +4720,17 @@ def test_from_dataframe_categorical(self) -> None:
assert len(ds["i1"]) == 2
assert len(ds["i2"]) == 2
+ def test_from_dataframe_categorical_string_categories(self) -> None:
+ cat = pd.CategoricalIndex(
+ pd.Categorical.from_codes(
+ np.array([1, 1, 0, 2]),
+ categories=pd.Index(["foo", "bar", "baz"], dtype="string"),
+ )
+ )
+ ser = pd.Series(1, index=cat)
+ ds = ser.to_xarray()
+ assert ds.coords.dtypes["index"] == np.dtype("O")
+
@requires_sparse
def test_from_dataframe_sparse(self) -> None:
import sparse
@@ -4960,7 +4994,7 @@ def test_pickle(self) -> None:
roundtripped = pickle.loads(pickle.dumps(data))
assert_identical(data, roundtripped)
# regression test for #167:
- assert data.dims == roundtripped.dims
+ assert data.sizes == roundtripped.sizes
def test_lazy_load(self) -> None:
store = InaccessibleVariableDataStore()
@@ -5418,7 +5452,7 @@ def test_reduce_non_numeric(self) -> None:
data2 = create_test_data(seed=44)
add_vars = {"var4": ["dim1", "dim2"], "var5": ["dim1"]}
for v, dims in sorted(add_vars.items()):
- size = tuple(data1.dims[d] for d in dims)
+ size = tuple(data1.sizes[d] for d in dims)
data = np.random.randint(0, 100, size=size).astype(np.str_)
data1[v] = (dims, data, {"foo": "variable"})
@@ -6467,7 +6501,7 @@ def test_pad(self) -> None:
assert padded["var1"].shape == (8, 11)
assert padded["var2"].shape == (8, 11)
assert padded["var3"].shape == (10, 8)
- assert dict(padded.dims) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20}
+ assert dict(padded.sizes) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20}
np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42)
np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan)
@@ -6684,6 +6718,23 @@ def test_query(self, backend, engine, parser) -> None:
# pytest tests — new tests should go here, rather than in the class.
+@pytest.mark.parametrize("parser", ["pandas", "python"])
+def test_eval(ds, parser) -> None:
+    """Currently this is much more minimal testing than `query` above, and much of the
+    setup isn't used. But the risks are fairly low: `query` shares much of the code, and
+    the method is currently experimental."""
+
+ actual = ds.eval("z1 + 5", parser=parser)
+ expect = ds["z1"] + 5
+ assert_identical(expect, actual)
+
+ # check pandas query syntax is supported
+ if parser == "pandas":
+ actual = ds.eval("(z1 > 5) and (z2 > 0)", parser=parser)
+ expect = (ds["z1"] > 5) & (ds["z2"] > 0)
+ assert_identical(expect, actual)
+
+
@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2])))
def test_isin(test_elements, backend) -> None:
expected = Dataset(
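
`Dataset.eval`, exercised above, evaluates a string expression against the dataset's variables, much like `pandas.DataFrame.eval`, and is flagged as experimental. A hedged usage sketch independent of the test fixture:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"z1": ("x", np.arange(5.0)), "z2": ("x", np.arange(5.0) - 2)})

xr.testing.assert_identical(ds.eval("z1 + 5", parser="pandas"), ds["z1"] + 5)

# With the pandas parser, boolean expressions may use `and`/`or`:
xr.testing.assert_identical(
    ds.eval("(z1 > 2) and (z2 > 0)", parser="pandas"),
    (ds["z1"] > 2) & (ds["z2"] > 0),
)
```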
@@ -7162,7 +7213,7 @@ def test_clip(ds) -> None:
assert all((result.max(...) <= 0.75).values())
result = ds.clip(min=ds.mean("y"), max=ds.mean("y"))
- assert result.dims == ds.dims
+ assert result.sizes == ds.sizes
class TestDropDuplicates:
diff --git a/xarray/tests/test_error_messages.py b/xarray/tests/test_error_messages.py
new file mode 100644
index 00000000000..b5840aafdfa
--- /dev/null
+++ b/xarray/tests/test_error_messages.py
@@ -0,0 +1,17 @@
+"""
+This new file is intended to test the quality & friendliness of error messages that are
+raised by xarray. It's currently separate from the standard tests, which are more
+focused on whether the functions work (though we could consider integrating them).
+"""
+
+import pytest
+
+
+def test_no_var_in_dataset(ds):
+ with pytest.raises(
+ KeyError,
+ match=(
+ r"No variable named 'foo'. Variables on the dataset include \['z1', 'z2', 'x', 'time', 'c', 'y'\]"
+ ),
+ ):
+ ds["foo"]
diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py
index 96bb9c8a3a7..181b0205352 100644
--- a/xarray/tests/test_formatting.py
+++ b/xarray/tests/test_formatting.py
@@ -773,3 +773,33 @@ def __array__(self, dtype=None):
# These will crash if var.data are converted to numpy arrays:
var.__repr__()
var._repr_html_()
+
+
+@pytest.mark.parametrize("as_dataset", (False, True))
+def test_format_xindexes_none(as_dataset: bool) -> None:
+ # ensure repr for empty xindexes can be displayed #8367
+
+ expected = """\
+ Indexes:
+ *empty*"""
+ expected = dedent(expected)
+
+ obj: xr.DataArray | xr.Dataset = xr.DataArray()
+ obj = obj._to_temp_dataset() if as_dataset else obj
+
+ actual = repr(obj.xindexes)
+ assert actual == expected
+
+
+@pytest.mark.parametrize("as_dataset", (False, True))
+def test_format_xindexes(as_dataset: bool) -> None:
+ expected = """\
+ Indexes:
+ x PandasIndex"""
+ expected = dedent(expected)
+
+ obj: xr.DataArray | xr.Dataset = xr.DataArray([1], coords={"x": [1]})
+ obj = obj._to_temp_dataset() if as_dataset else obj
+
+ actual = repr(obj.xindexes)
+ assert actual == expected
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index b166992deb1..84820d56c45 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -59,6 +59,7 @@ def test_consolidate_slices() -> None:
_consolidate_slices([slice(3), 4]) # type: ignore[list-item]
+@pytest.mark.filterwarnings("ignore:return type")
def test_groupby_dims_property(dataset) -> None:
assert dataset.groupby("x").dims == dataset.isel(x=1).dims
assert dataset.groupby("y").dims == dataset.isel(y=1).dims
@@ -67,6 +68,14 @@ def test_groupby_dims_property(dataset) -> None:
assert stacked.groupby("xy").dims == stacked.isel(xy=0).dims
+def test_groupby_sizes_property(dataset) -> None:
+ assert dataset.groupby("x").sizes == dataset.isel(x=1).sizes
+ assert dataset.groupby("y").sizes == dataset.isel(y=1).sizes
+
+ stacked = dataset.stack({"xy": ("x", "y")})
+ assert stacked.groupby("xy").sizes == stacked.isel(xy=0).sizes
+
+
def test_multi_index_groupby_map(dataset) -> None:
# regression test for GH873
ds = dataset.isel(z=1, drop=True)[["foo"]]
diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py
index e318bf01a7e..45a649605f3 100644
--- a/xarray/tests/test_missing.py
+++ b/xarray/tests/test_missing.py
@@ -24,6 +24,8 @@
requires_bottleneck,
requires_cftime,
requires_dask,
+ requires_numbagg,
+ requires_numbagg_or_bottleneck,
requires_scipy,
)
@@ -407,7 +409,7 @@ def test_interpolate_dask_expected_dtype(dtype, method):
assert da.dtype == da.compute().dtype
-@requires_bottleneck
+@requires_numbagg_or_bottleneck
def test_ffill():
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
expected = xr.DataArray(np.array([4, 5, 5], dtype=np.float64), dims="x")
@@ -415,9 +417,9 @@ def test_ffill():
assert_equal(actual, expected)
-def test_ffill_use_bottleneck():
+def test_ffill_use_bottleneck_numbagg():
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
- with xr.set_options(use_bottleneck=False):
+ with xr.set_options(use_bottleneck=False, use_numbagg=False):
with pytest.raises(RuntimeError):
da.ffill("x")
@@ -426,14 +428,24 @@ def test_ffill_use_bottleneck():
def test_ffill_use_bottleneck_dask():
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
da = da.chunk({"x": 1})
- with xr.set_options(use_bottleneck=False):
+ with xr.set_options(use_bottleneck=False, use_numbagg=False):
with pytest.raises(RuntimeError):
da.ffill("x")
+@requires_numbagg
+@requires_dask
+def test_ffill_use_numbagg_dask():
+ with xr.set_options(use_bottleneck=False):
+ da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
+ da = da.chunk(x=-1)
+ # Succeeds with a single chunk:
+ _ = da.ffill("x").compute()
+
+
def test_bfill_use_bottleneck():
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
- with xr.set_options(use_bottleneck=False):
+ with xr.set_options(use_bottleneck=False, use_numbagg=False):
with pytest.raises(RuntimeError):
da.bfill("x")
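
The extra `use_numbagg=False` in these tests reflects that `ffill`/`bfill` can now be served by numbagg as well as bottleneck, so the "no accelerator available" error is only reachable with both options disabled. A sketch of the two situations (the first call assumes at least one of numbagg or bottleneck is installed):

```python
import numpy as np
import xarray as xr

da = xr.DataArray([4.0, 5.0, np.nan], dims="x")

# Uses numbagg or bottleneck, whichever is installed and enabled:
filled = da.ffill("x")
assert filled[-1] == 5.0

# With both accelerators disabled, ffill/bfill have no implementation and raise:
with xr.set_options(use_bottleneck=False, use_numbagg=False):
    try:
        da.ffill("x")
    except RuntimeError:
        pass
```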
@@ -442,7 +454,7 @@ def test_bfill_use_bottleneck():
def test_bfill_use_bottleneck_dask():
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
da = da.chunk({"x": 1})
- with xr.set_options(use_bottleneck=False):
+ with xr.set_options(use_bottleneck=False, use_numbagg=False):
with pytest.raises(RuntimeError):
da.bfill("x")
@@ -536,7 +548,7 @@ def test_ffill_limit():
def test_interpolate_dataset(ds):
actual = ds.interpolate_na(dim="time")
# no missing values in var1
- assert actual["var1"].count("time") == actual.dims["time"]
+ assert actual["var1"].count("time") == actual.sizes["time"]
# var2 should be the same as it was
assert_array_equal(actual["var2"], ds["var2"])
diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py
index e0141e12755..9aedcbc80d4 100644
--- a/xarray/tests/test_namedarray.py
+++ b/xarray/tests/test_namedarray.py
@@ -2,6 +2,7 @@
import copy
import warnings
+from abc import abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Generic, cast, overload
@@ -9,9 +10,13 @@
import pytest
from xarray.core.indexing import ExplicitlyIndexed
-from xarray.namedarray._typing import _arrayfunction_or_api, _DType_co, _ShapeType_co
+from xarray.namedarray._typing import (
+ _arrayfunction_or_api,
+ _default,
+ _DType_co,
+ _ShapeType_co,
+)
from xarray.namedarray.core import NamedArray, from_array
-from xarray.namedarray.utils import _default
if TYPE_CHECKING:
from types import ModuleType
@@ -19,13 +24,14 @@
from numpy.typing import ArrayLike, DTypeLike, NDArray
from xarray.namedarray._typing import (
+ Default,
_AttrsLike,
_DimsLike,
_DType,
+ _IndexKeyLike,
_Shape,
duckarray,
)
- from xarray.namedarray.utils import Default
class CustomArrayBase(Generic[_ShapeType_co, _DType_co]):
@@ -53,391 +59,441 @@ class CustomArrayIndexable(
ExplicitlyIndexed,
Generic[_ShapeType_co, _DType_co],
):
+ def __getitem__(
+ self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], /
+ ) -> CustomArrayIndexable[Any, _DType_co]:
+ if isinstance(key, CustomArrayIndexable):
+ if isinstance(key.array, type(self.array)):
+ # TODO: key.array is duckarray here, can it be narrowed down further?
+ # an _arrayapi cannot be used on a _arrayfunction for example.
+ return type(self)(array=self.array[key.array]) # type: ignore[index]
+ else:
+ raise TypeError("key must have the same array type as self")
+ else:
+ return type(self)(array=self.array[key])
+
def __array_namespace__(self) -> ModuleType:
return np
-@pytest.fixture
-def random_inputs() -> np.ndarray[Any, np.dtype[np.float32]]:
- return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
-
-
-def test_namedarray_init() -> None:
- dtype = np.dtype(np.int8)
- expected = np.array([1, 2], dtype=dtype)
- actual: NamedArray[Any, np.dtype[np.int8]]
- actual = NamedArray(("x",), expected)
- assert np.array_equal(np.asarray(actual.data), expected)
-
- with pytest.raises(AttributeError):
- expected2 = [1, 2]
- actual2: NamedArray[Any, Any]
- actual2 = NamedArray(("x",), expected2) # type: ignore[arg-type]
- assert np.array_equal(np.asarray(actual2.data), expected2)
-
-
-@pytest.mark.parametrize(
- "dims, data, expected, raise_error",
- [
- (("x",), [1, 2, 3], np.array([1, 2, 3]), False),
- ((1,), np.array([4, 5, 6]), np.array([4, 5, 6]), False),
- ((), 2, np.array(2), False),
- # Fail:
- (("x",), NamedArray("time", np.array([1, 2, 3])), np.array([1, 2, 3]), True),
- ],
-)
-def test_from_array(
- dims: _DimsLike,
- data: ArrayLike,
- expected: np.ndarray[Any, Any],
- raise_error: bool,
-) -> None:
- actual: NamedArray[Any, Any]
- if raise_error:
- with pytest.raises(TypeError, match="already a Named array"):
- actual = from_array(dims, data)
-
- # Named arrays are not allowed:
- from_array(actual) # type: ignore[call-overload]
- else:
- actual = from_array(dims, data)
-
+class NamedArraySubclassobjects:
+ @pytest.fixture
+ def target(self, data: np.ndarray[Any, Any]) -> Any:
+ """Fixture that needs to be overridden"""
+ raise NotImplementedError
+
+ @abstractmethod
+ def cls(self, *args: Any, **kwargs: Any) -> Any:
+ """Method that needs to be overridden"""
+ raise NotImplementedError
+
+ @pytest.fixture
+ def data(self) -> np.ndarray[Any, np.dtype[Any]]:
+ return 0.5 * np.arange(10).reshape(2, 5)
+
+ @pytest.fixture
+ def random_inputs(self) -> np.ndarray[Any, np.dtype[np.float32]]:
+ return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
+
+ def test_properties(self, target: Any, data: Any) -> None:
+ assert target.dims == ("x", "y")
+ assert np.array_equal(target.data, data)
+ assert target.dtype == float
+ assert target.shape == (2, 5)
+ assert target.ndim == 2
+ assert target.sizes == {"x": 2, "y": 5}
+ assert target.size == 10
+ assert target.nbytes == 80
+ assert len(target) == 2
+
+ def test_attrs(self, target: Any) -> None:
+ assert target.attrs == {}
+ attrs = {"foo": "bar"}
+ target.attrs = attrs
+ assert target.attrs == attrs
+ assert isinstance(target.attrs, dict)
+ target.attrs["foo"] = "baz"
+ assert target.attrs["foo"] == "baz"
+
+ @pytest.mark.parametrize(
+ "expected", [np.array([1, 2], dtype=np.dtype(np.int8)), [1, 2]]
+ )
+ def test_init(self, expected: Any) -> None:
+ actual = self.cls(("x",), expected)
assert np.array_equal(np.asarray(actual.data), expected)
+ actual = self.cls(("x",), expected)
+ assert np.array_equal(np.asarray(actual.data), expected)
-def test_from_array_with_masked_array() -> None:
- masked_array: np.ndarray[Any, np.dtype[np.generic]]
- masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call]
- with pytest.raises(NotImplementedError):
- from_array(("x",), masked_array)
-
-
-def test_from_array_with_0d_object() -> None:
- data = np.empty((), dtype=object)
- data[()] = (10, 12, 12)
- narr = from_array((), data)
- np.array_equal(np.asarray(narr.data), data)
-
-
-# TODO: Make xr.core.indexing.ExplicitlyIndexed pass as a subclass of_arrayfunction_or_api
-# and remove this test.
-def test_from_array_with_explicitly_indexed(
- random_inputs: np.ndarray[Any, Any]
-) -> None:
- array: CustomArray[Any, Any]
- array = CustomArray(random_inputs)
- output: NamedArray[Any, Any]
- output = from_array(("x", "y", "z"), array)
- assert isinstance(output.data, np.ndarray)
-
- array2: CustomArrayIndexable[Any, Any]
- array2 = CustomArrayIndexable(random_inputs)
- output2: NamedArray[Any, Any]
- output2 = from_array(("x", "y", "z"), array2)
- assert isinstance(output2.data, CustomArrayIndexable)
-
-
-def test_properties() -> None:
- data = 0.5 * np.arange(10).reshape(2, 5)
- named_array: NamedArray[Any, Any]
- named_array = NamedArray(["x", "y"], data, {"key": "value"})
- assert named_array.dims == ("x", "y")
- assert np.array_equal(np.asarray(named_array.data), data)
- assert named_array.attrs == {"key": "value"}
- assert named_array.ndim == 2
- assert named_array.sizes == {"x": 2, "y": 5}
- assert named_array.size == 10
- assert named_array.nbytes == 80
- assert len(named_array) == 2
-
-
-def test_attrs() -> None:
- named_array: NamedArray[Any, Any]
- named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5))
- assert named_array.attrs == {}
- named_array.attrs["key"] = "value"
- assert named_array.attrs == {"key": "value"}
- named_array.attrs = {"key": "value2"}
- assert named_array.attrs == {"key": "value2"}
-
-
-def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
- named_array: NamedArray[Any, Any]
- named_array = NamedArray(["x", "y", "z"], random_inputs)
- assert np.array_equal(np.asarray(named_array.data), random_inputs)
- with pytest.raises(ValueError):
- named_array.data = np.random.random((3, 4)).astype(np.float64)
-
-
-def test_real_and_imag() -> None:
- expected_real: np.ndarray[Any, np.dtype[np.float64]]
- expected_real = np.arange(3, dtype=np.float64)
-
- expected_imag: np.ndarray[Any, np.dtype[np.float64]]
- expected_imag = -np.arange(3, dtype=np.float64)
-
- arr: np.ndarray[Any, np.dtype[np.complex128]]
- arr = expected_real + 1j * expected_imag
-
- named_array: NamedArray[Any, np.dtype[np.complex128]]
- named_array = NamedArray(["x"], arr)
-
- actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data
- assert np.array_equal(np.asarray(actual_real), expected_real)
- assert actual_real.dtype == expected_real.dtype
-
- actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data
- assert np.array_equal(np.asarray(actual_imag), expected_imag)
- assert actual_imag.dtype == expected_imag.dtype
-
-
-# Additional tests as per your original class-based code
-@pytest.mark.parametrize(
- "data, dtype",
- [
- ("foo", np.dtype("U3")),
- (b"foo", np.dtype("S3")),
- ],
-)
-def test_0d_string(data: Any, dtype: DTypeLike) -> None:
- named_array: NamedArray[Any, Any]
- named_array = from_array([], data)
- assert named_array.data == data
- assert named_array.dims == ()
- assert named_array.sizes == {}
- assert named_array.attrs == {}
- assert named_array.ndim == 0
- assert named_array.size == 1
- assert named_array.dtype == dtype
-
-
-def test_0d_object() -> None:
- named_array: NamedArray[Any, Any]
- named_array = from_array([], (10, 12, 12))
- expected_data = np.empty((), dtype=object)
- expected_data[()] = (10, 12, 12)
- assert np.array_equal(np.asarray(named_array.data), expected_data)
-
- assert named_array.dims == ()
- assert named_array.sizes == {}
- assert named_array.attrs == {}
- assert named_array.ndim == 0
- assert named_array.size == 1
- assert named_array.dtype == np.dtype("O")
-
-
-def test_0d_datetime() -> None:
- named_array: NamedArray[Any, Any]
- named_array = from_array([], np.datetime64("2000-01-01"))
- assert named_array.dtype == np.dtype("datetime64[D]")
-
-
-@pytest.mark.parametrize(
- "timedelta, expected_dtype",
- [
- (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")),
- (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")),
- (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")),
- (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")),
- (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")),
- (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")),
- (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")),
- (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")),
- (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")),
- ],
-)
-def test_0d_timedelta(
- timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64]
-) -> None:
- named_array: NamedArray[Any, Any]
- named_array = from_array([], timedelta)
- assert named_array.dtype == expected_dtype
- assert named_array.data == timedelta
-
-
-@pytest.mark.parametrize(
- "dims, data_shape, new_dims, raises",
- [
- (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False),
- (["x", "y", "z"], (2, 3, 4), ["a", "b"], True),
- (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True),
- ([], [], (), False),
- ([], [], ("x",), True),
- ],
-)
-def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None:
- named_array: NamedArray[Any, Any]
- named_array = NamedArray(dims, np.asarray(np.random.random(data_shape)))
- assert named_array.dims == tuple(dims)
- if raises:
+ def test_data(self, random_inputs: Any) -> None:
+ expected = self.cls(["x", "y", "z"], random_inputs)
+ assert np.array_equal(np.asarray(expected.data), random_inputs)
with pytest.raises(ValueError):
- named_array.dims = new_dims
- else:
- named_array.dims = new_dims
- assert named_array.dims == tuple(new_dims)
-
-
-def test_duck_array_class() -> None:
- def test_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]:
- # Mypy checks a is valid:
- b: duckarray[Any, _DType] = a
-
- # Runtime check if valid:
- if isinstance(b, _arrayfunction_or_api):
- return b
+ expected.data = np.random.random((3, 4)).astype(np.float64)
+ d2 = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
+ expected.data = d2
+ assert np.array_equal(np.asarray(expected.data), d2)
+
+
+class TestNamedArray(NamedArraySubclassobjects):
+ def cls(self, *args: Any, **kwargs: Any) -> NamedArray[Any, Any]:
+ return NamedArray(*args, **kwargs)
+
+ @pytest.fixture
+ def target(self, data: np.ndarray[Any, Any]) -> NamedArray[Any, Any]:
+ return NamedArray(["x", "y"], data)
+
+ @pytest.mark.parametrize(
+ "expected",
+ [
+ np.array([1, 2], dtype=np.dtype(np.int8)),
+ pytest.param(
+ [1, 2],
+ marks=pytest.mark.xfail(
+ reason="NamedArray only supports array-like objects"
+ ),
+ ),
+ ],
+ )
+ def test_init(self, expected: Any) -> None:
+ super().test_init(expected)
+
+ @pytest.mark.parametrize(
+ "dims, data, expected, raise_error",
+ [
+ (("x",), [1, 2, 3], np.array([1, 2, 3]), False),
+ ((1,), np.array([4, 5, 6]), np.array([4, 5, 6]), False),
+ ((), 2, np.array(2), False),
+ # Fail:
+ (
+ ("x",),
+ NamedArray("time", np.array([1, 2, 3])),
+ np.array([1, 2, 3]),
+ True,
+ ),
+ ],
+ )
+ def test_from_array(
+ self,
+ dims: _DimsLike,
+ data: ArrayLike,
+ expected: np.ndarray[Any, Any],
+ raise_error: bool,
+ ) -> None:
+ actual: NamedArray[Any, Any]
+ if raise_error:
+ with pytest.raises(TypeError, match="already a Named array"):
+ actual = from_array(dims, data)
+
+ # Named arrays are not allowed:
+ from_array(actual) # type: ignore[call-overload]
else:
- raise TypeError(f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi")
-
- numpy_a: NDArray[np.int64]
- numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64))
- test_duck_array_typevar(numpy_a)
-
- masked_a: np.ma.MaskedArray[Any, np.dtype[np.int64]]
- masked_a = np.ma.asarray([2.1, 4], dtype=np.dtype(np.int64)) # type: ignore[no-untyped-call]
- test_duck_array_typevar(masked_a)
-
- custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]]
- custom_a = CustomArrayIndexable(numpy_a)
- test_duck_array_typevar(custom_a)
-
- # Test numpy's array api:
- with warnings.catch_warnings():
- warnings.filterwarnings(
- "ignore",
- r"The numpy.array_api submodule is still experimental",
- category=UserWarning,
- )
- import numpy.array_api as nxp
-
- # TODO: nxp doesn't use dtype typevars, so can only use Any for the moment:
- arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]]
- arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64))
- test_duck_array_typevar(arrayapi_a)
-
-
-def test_new_namedarray() -> None:
- dtype_float = np.dtype(np.float32)
- narr_float: NamedArray[Any, np.dtype[np.float32]]
- narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float))
- assert narr_float.dtype == dtype_float
-
- dtype_int = np.dtype(np.int8)
- narr_int: NamedArray[Any, np.dtype[np.int8]]
- narr_int = narr_float._new(("x",), np.array([1, 3], dtype=dtype_int))
- assert narr_int.dtype == dtype_int
-
- # Test with a subclass:
- class Variable(
- NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co]
- ):
- @overload
- def _new(
- self,
- dims: _DimsLike | Default = ...,
- data: duckarray[Any, _DType] = ...,
- attrs: _AttrsLike | Default = ...,
- ) -> Variable[Any, _DType]:
- ...
-
- @overload
- def _new(
- self,
- dims: _DimsLike | Default = ...,
- data: Default = ...,
- attrs: _AttrsLike | Default = ...,
- ) -> Variable[_ShapeType_co, _DType_co]:
- ...
-
- def _new(
- self,
- dims: _DimsLike | Default = _default,
- data: duckarray[Any, _DType] | Default = _default,
- attrs: _AttrsLike | Default = _default,
- ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]:
- dims_ = copy.copy(self._dims) if dims is _default else dims
-
- attrs_: Mapping[Any, Any] | None
- if attrs is _default:
- attrs_ = None if self._attrs is None else self._attrs.copy()
- else:
- attrs_ = attrs
-
- if data is _default:
- return type(self)(dims_, copy.copy(self._data), attrs_)
- else:
- cls_ = cast("type[Variable[Any, _DType]]", type(self))
- return cls_(dims_, data, attrs_)
-
- var_float: Variable[Any, np.dtype[np.float32]]
- var_float = Variable(("x",), np.array([1.5, 3.2], dtype=dtype_float))
- assert var_float.dtype == dtype_float
-
- var_int: Variable[Any, np.dtype[np.int8]]
- var_int = var_float._new(("x",), np.array([1, 3], dtype=dtype_int))
- assert var_int.dtype == dtype_int
-
-
-def test_replace_namedarray() -> None:
- dtype_float = np.dtype(np.float32)
- np_val: np.ndarray[Any, np.dtype[np.float32]]
- np_val = np.array([1.5, 3.2], dtype=dtype_float)
- np_val2: np.ndarray[Any, np.dtype[np.float32]]
- np_val2 = 2 * np_val
-
- narr_float: NamedArray[Any, np.dtype[np.float32]]
- narr_float = NamedArray(("x",), np_val)
- assert narr_float.dtype == dtype_float
-
- narr_float2: NamedArray[Any, np.dtype[np.float32]]
- narr_float2 = NamedArray(("x",), np_val2)
- assert narr_float2.dtype == dtype_float
-
- # Test with a subclass:
- class Variable(
- NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co]
- ):
- @overload
- def _new(
- self,
- dims: _DimsLike | Default = ...,
- data: duckarray[Any, _DType] = ...,
- attrs: _AttrsLike | Default = ...,
- ) -> Variable[Any, _DType]:
- ...
-
- @overload
- def _new(
- self,
- dims: _DimsLike | Default = ...,
- data: Default = ...,
- attrs: _AttrsLike | Default = ...,
- ) -> Variable[_ShapeType_co, _DType_co]:
- ...
-
- def _new(
- self,
- dims: _DimsLike | Default = _default,
- data: duckarray[Any, _DType] | Default = _default,
- attrs: _AttrsLike | Default = _default,
- ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]:
- dims_ = copy.copy(self._dims) if dims is _default else dims
-
- attrs_: Mapping[Any, Any] | None
- if attrs is _default:
- attrs_ = None if self._attrs is None else self._attrs.copy()
- else:
- attrs_ = attrs
+ actual = from_array(dims, data)
- if data is _default:
- return type(self)(dims_, copy.copy(self._data), attrs_)
+ assert np.array_equal(np.asarray(actual.data), expected)
+
+ def test_from_array_with_masked_array(self) -> None:
+ masked_array: np.ndarray[Any, np.dtype[np.generic]]
+ masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call]
+ with pytest.raises(NotImplementedError):
+ from_array(("x",), masked_array)
+
+ def test_from_array_with_0d_object(self) -> None:
+ data = np.empty((), dtype=object)
+ data[()] = (10, 12, 12)
+ narr = from_array((), data)
+ np.array_equal(np.asarray(narr.data), data)
+
+    # TODO: Make xr.core.indexing.ExplicitlyIndexed pass as a subclass of _arrayfunction_or_api
+ # and remove this test.
+ def test_from_array_with_explicitly_indexed(
+ self, random_inputs: np.ndarray[Any, Any]
+ ) -> None:
+ array: CustomArray[Any, Any]
+ array = CustomArray(random_inputs)
+ output: NamedArray[Any, Any]
+ output = from_array(("x", "y", "z"), array)
+ assert isinstance(output.data, np.ndarray)
+
+ array2: CustomArrayIndexable[Any, Any]
+ array2 = CustomArrayIndexable(random_inputs)
+ output2: NamedArray[Any, Any]
+ output2 = from_array(("x", "y", "z"), array2)
+ assert isinstance(output2.data, CustomArrayIndexable)
+
+ def test_real_and_imag(self) -> None:
+ expected_real: np.ndarray[Any, np.dtype[np.float64]]
+ expected_real = np.arange(3, dtype=np.float64)
+
+ expected_imag: np.ndarray[Any, np.dtype[np.float64]]
+ expected_imag = -np.arange(3, dtype=np.float64)
+
+ arr: np.ndarray[Any, np.dtype[np.complex128]]
+ arr = expected_real + 1j * expected_imag
+
+ named_array: NamedArray[Any, np.dtype[np.complex128]]
+ named_array = NamedArray(["x"], arr)
+
+ actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data
+ assert np.array_equal(np.asarray(actual_real), expected_real)
+ assert actual_real.dtype == expected_real.dtype
+
+ actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data
+ assert np.array_equal(np.asarray(actual_imag), expected_imag)
+ assert actual_imag.dtype == expected_imag.dtype
+
+    # Additional tests carried over from the previous function-based tests
+ @pytest.mark.parametrize(
+ "data, dtype",
+ [
+ ("foo", np.dtype("U3")),
+ (b"foo", np.dtype("S3")),
+ ],
+ )
+ def test_from_array_0d_string(self, data: Any, dtype: DTypeLike) -> None:
+ named_array: NamedArray[Any, Any]
+ named_array = from_array([], data)
+ assert named_array.data == data
+ assert named_array.dims == ()
+ assert named_array.sizes == {}
+ assert named_array.attrs == {}
+ assert named_array.ndim == 0
+ assert named_array.size == 1
+ assert named_array.dtype == dtype
+
+ def test_from_array_0d_object(self) -> None:
+ named_array: NamedArray[Any, Any]
+ named_array = from_array([], (10, 12, 12))
+ expected_data = np.empty((), dtype=object)
+ expected_data[()] = (10, 12, 12)
+ assert np.array_equal(np.asarray(named_array.data), expected_data)
+
+ assert named_array.dims == ()
+ assert named_array.sizes == {}
+ assert named_array.attrs == {}
+ assert named_array.ndim == 0
+ assert named_array.size == 1
+ assert named_array.dtype == np.dtype("O")
+
+ def test_from_array_0d_datetime(self) -> None:
+ named_array: NamedArray[Any, Any]
+ named_array = from_array([], np.datetime64("2000-01-01"))
+ assert named_array.dtype == np.dtype("datetime64[D]")
+
+ @pytest.mark.parametrize(
+ "timedelta, expected_dtype",
+ [
+ (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")),
+ (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")),
+ (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")),
+ (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")),
+ (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")),
+ (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")),
+ (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")),
+ (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")),
+ (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")),
+ ],
+ )
+ def test_from_array_0d_timedelta(
+ self, timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64]
+ ) -> None:
+ named_array: NamedArray[Any, Any]
+ named_array = from_array([], timedelta)
+ assert named_array.dtype == expected_dtype
+ assert named_array.data == timedelta
+
+ @pytest.mark.parametrize(
+ "dims, data_shape, new_dims, raises",
+ [
+ (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False),
+ (["x", "y", "z"], (2, 3, 4), ["a", "b"], True),
+ (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True),
+ ([], [], (), False),
+ ([], [], ("x",), True),
+ ],
+ )
+ def test_dims_setter(
+ self, dims: Any, data_shape: Any, new_dims: Any, raises: bool
+ ) -> None:
+ named_array: NamedArray[Any, Any]
+ named_array = NamedArray(dims, np.asarray(np.random.random(data_shape)))
+ assert named_array.dims == tuple(dims)
+ if raises:
+ with pytest.raises(ValueError):
+ named_array.dims = new_dims
+ else:
+ named_array.dims = new_dims
+ assert named_array.dims == tuple(new_dims)
+
+ def test_duck_array_class(
+ self,
+ ) -> None:
+ def test_duck_array_typevar(
+ a: duckarray[Any, _DType]
+ ) -> duckarray[Any, _DType]:
+ # Mypy checks a is valid:
+ b: duckarray[Any, _DType] = a
+
+ # Runtime check if valid:
+ if isinstance(b, _arrayfunction_or_api):
+ return b
else:
- cls_ = cast("type[Variable[Any, _DType]]", type(self))
- return cls_(dims_, data, attrs_)
-
- var_float: Variable[Any, np.dtype[np.float32]]
- var_float = Variable(("x",), np_val)
- assert var_float.dtype == dtype_float
-
- var_float2: Variable[Any, np.dtype[np.float32]]
- var_float2 = var_float._replace(("x",), np_val2)
- assert var_float2.dtype == dtype_float
+ raise TypeError(
+ f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi"
+ )
+
+ numpy_a: NDArray[np.int64]
+ numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64))
+ test_duck_array_typevar(numpy_a)
+
+ masked_a: np.ma.MaskedArray[Any, np.dtype[np.int64]]
+ masked_a = np.ma.asarray([2.1, 4], dtype=np.dtype(np.int64)) # type: ignore[no-untyped-call]
+ test_duck_array_typevar(masked_a)
+
+ custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]]
+ custom_a = CustomArrayIndexable(numpy_a)
+ test_duck_array_typevar(custom_a)
+
+ # Test numpy's array api:
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore",
+ r"The numpy.array_api submodule is still experimental",
+ category=UserWarning,
+ )
+ import numpy.array_api as nxp
+
+ # TODO: nxp doesn't use dtype typevars, so can only use Any for the moment:
+ arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]]
+ arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64))
+ test_duck_array_typevar(arrayapi_a)
+
+ def test_new_namedarray(self) -> None:
+ dtype_float = np.dtype(np.float32)
+ narr_float: NamedArray[Any, np.dtype[np.float32]]
+ narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float))
+ assert narr_float.dtype == dtype_float
+
+ dtype_int = np.dtype(np.int8)
+ narr_int: NamedArray[Any, np.dtype[np.int8]]
+ narr_int = narr_float._new(("x",), np.array([1, 3], dtype=dtype_int))
+ assert narr_int.dtype == dtype_int
+
+ # Test with a subclass:
+ class Variable(
+ NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co]
+ ):
+ @overload
+ def _new(
+ self,
+ dims: _DimsLike | Default = ...,
+ data: duckarray[Any, _DType] = ...,
+ attrs: _AttrsLike | Default = ...,
+ ) -> Variable[Any, _DType]:
+ ...
+
+ @overload
+ def _new(
+ self,
+ dims: _DimsLike | Default = ...,
+ data: Default = ...,
+ attrs: _AttrsLike | Default = ...,
+ ) -> Variable[_ShapeType_co, _DType_co]:
+ ...
+
+ def _new(
+ self,
+ dims: _DimsLike | Default = _default,
+ data: duckarray[Any, _DType] | Default = _default,
+ attrs: _AttrsLike | Default = _default,
+ ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]:
+ dims_ = copy.copy(self._dims) if dims is _default else dims
+
+ attrs_: Mapping[Any, Any] | None
+ if attrs is _default:
+ attrs_ = None if self._attrs is None else self._attrs.copy()
+ else:
+ attrs_ = attrs
+
+ if data is _default:
+ return type(self)(dims_, copy.copy(self._data), attrs_)
+ else:
+ cls_ = cast("type[Variable[Any, _DType]]", type(self))
+ return cls_(dims_, data, attrs_)
+
+ var_float: Variable[Any, np.dtype[np.float32]]
+ var_float = Variable(("x",), np.array([1.5, 3.2], dtype=dtype_float))
+ assert var_float.dtype == dtype_float
+
+ var_int: Variable[Any, np.dtype[np.int8]]
+ var_int = var_float._new(("x",), np.array([1, 3], dtype=dtype_int))
+ assert var_int.dtype == dtype_int
+
+ def test_replace_namedarray(self) -> None:
+ dtype_float = np.dtype(np.float32)
+ np_val: np.ndarray[Any, np.dtype[np.float32]]
+ np_val = np.array([1.5, 3.2], dtype=dtype_float)
+ np_val2: np.ndarray[Any, np.dtype[np.float32]]
+ np_val2 = 2 * np_val
+
+ narr_float: NamedArray[Any, np.dtype[np.float32]]
+ narr_float = NamedArray(("x",), np_val)
+ assert narr_float.dtype == dtype_float
+
+ narr_float2: NamedArray[Any, np.dtype[np.float32]]
+ narr_float2 = NamedArray(("x",), np_val2)
+ assert narr_float2.dtype == dtype_float
+
+ # Test with a subclass:
+ class Variable(
+ NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co]
+ ):
+ @overload
+ def _new(
+ self,
+ dims: _DimsLike | Default = ...,
+ data: duckarray[Any, _DType] = ...,
+ attrs: _AttrsLike | Default = ...,
+ ) -> Variable[Any, _DType]:
+ ...
+
+ @overload
+ def _new(
+ self,
+ dims: _DimsLike | Default = ...,
+ data: Default = ...,
+ attrs: _AttrsLike | Default = ...,
+ ) -> Variable[_ShapeType_co, _DType_co]:
+ ...
+
+ def _new(
+ self,
+ dims: _DimsLike | Default = _default,
+ data: duckarray[Any, _DType] | Default = _default,
+ attrs: _AttrsLike | Default = _default,
+ ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]:
+ dims_ = copy.copy(self._dims) if dims is _default else dims
+
+ attrs_: Mapping[Any, Any] | None
+ if attrs is _default:
+ attrs_ = None if self._attrs is None else self._attrs.copy()
+ else:
+ attrs_ = attrs
+
+ if data is _default:
+ return type(self)(dims_, copy.copy(self._data), attrs_)
+ else:
+ cls_ = cast("type[Variable[Any, _DType]]", type(self))
+ return cls_(dims_, data, attrs_)
+
+ var_float: Variable[Any, np.dtype[np.float32]]
+ var_float = Variable(("x",), np_val)
+ assert var_float.dtype == dtype_float
+
+ var_float2: Variable[Any, np.dtype[np.float32]]
+ var_float2 = var_float._replace(("x",), np_val2)
+ assert var_float2.dtype == dtype_float
+
+ def test_warn_on_repeated_dimension_names(self) -> None:
+ with pytest.warns(UserWarning, match="Duplicate dimension names"):
+ NamedArray(("x", "x"), np.arange(4).reshape(2, 2))
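
The test_namedarray rewrite folds the former module-level functions into a reusable base class, `NamedArraySubclassobjects`, whose `data`/`target` fixtures and `cls` factory are overridden by concrete suites such as `TestNamedArray`; other array classes can then inherit the same tests. A minimal sketch of the pattern, with placeholder names:

```python
import pytest


class BaseSuite:
    # Not collected by pytest itself (the name does not start with "Test").
    @pytest.fixture
    def data(self):
        return [1, 2, 3]

    def cls(self, *args, **kwargs):  # overridden by concrete subclasses
        raise NotImplementedError

    def test_roundtrip(self, data):
        obj = self.cls(data)
        assert list(obj) == data


class TestListBacked(BaseSuite):
    def cls(self, *args, **kwargs):
        return list(*args, **kwargs)
```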
diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py
index 31c23955b02..697db9c5e80 100644
--- a/xarray/tests/test_plot.py
+++ b/xarray/tests/test_plot.py
@@ -787,12 +787,17 @@ def test_plot_nans(self) -> None:
self.darray[1] = np.nan
self.darray.plot.line()
- def test_x_ticks_are_rotated_for_time(self) -> None:
+ def test_dates_are_concise(self) -> None:
+ import matplotlib.dates as mdates
+
time = pd.date_range("2000-01-01", "2000-01-10")
a = DataArray(np.arange(len(time)), [("t", time)])
a.plot.line()
- rotation = plt.gca().get_xticklabels()[0].get_rotation()
- assert rotation != 0
+
+ ax = plt.gca()
+
+ assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator)
+ assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter)
def test_xyincrease_false_changes_axes(self) -> None:
self.darray.plot.line(xincrease=False, yincrease=False)
@@ -1356,12 +1361,17 @@ def test_xyincrease_true_changes_axes(self) -> None:
diffs = xlim[0] - 0, xlim[1] - 14, ylim[0] - 0, ylim[1] - 9
assert all(abs(x) < 1 for x in diffs)
- def test_x_ticks_are_rotated_for_time(self) -> None:
+ def test_dates_are_concise(self) -> None:
+ import matplotlib.dates as mdates
+
time = pd.date_range("2000-01-01", "2000-01-10")
a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)])
- a.plot(x="t")
- rotation = plt.gca().get_xticklabels()[0].get_rotation()
- assert rotation != 0
+ self.plotfunc(a, x="t")
+
+ ax = plt.gca()
+
+ assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator)
+ assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter)
def test_plot_nans(self) -> None:
x1 = self.darray[:5]
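
The plotting tests now assert that datetime axes carry matplotlib's `AutoDateLocator` and `ConciseDateFormatter` rather than checking for rotated tick labels, which is presumably what xarray's date-axis handling sets up. A standalone sketch of the assertion in plain matplotlib:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

time = pd.date_range("2000-01-01", "2000-01-10")
fig, ax = plt.subplots()
ax.plot(time, np.arange(len(time)))

# This is the axis setup the tests above expect on datetime axes:
locator = mdates.AutoDateLocator()
ax.xaxis.set_major_locator(locator)
ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))

assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter)
```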
@@ -1888,6 +1898,25 @@ def test_interval_breaks_logspace(self) -> None:
class TestImshow(Common2dMixin, PlotTestCase):
plotfunc = staticmethod(xplt.imshow)
+ @pytest.mark.xfail(
+ reason=(
+ "Failing inside matplotlib. Should probably be fixed upstream because "
+ "other plot functions can handle it. "
+ "Remove this test when it works, already in Common2dMixin"
+            "Remove this override when it passes; the test already exists in Common2dMixin"
+ )
+ def test_dates_are_concise(self) -> None:
+ import matplotlib.dates as mdates
+
+ time = pd.date_range("2000-01-01", "2000-01-10")
+ a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)])
+ self.plotfunc(a, x="t")
+
+ ax = plt.gca()
+
+ assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator)
+ assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter)
+
@pytest.mark.slow
def test_imshow_called(self) -> None:
# Having both statements ensures the test works properly
@@ -2032,6 +2061,25 @@ class TestSurface(Common2dMixin, PlotTestCase):
plotfunc = staticmethod(xplt.surface)
subplot_kws = {"projection": "3d"}
+ @pytest.mark.xfail(
+ reason=(
+ "Failing inside matplotlib. Should probably be fixed upstream because "
+ "other plot functions can handle it. "
+            "Remove this override once it works; the test is already defined in Common2dMixin"
+ )
+ )
+ def test_dates_are_concise(self) -> None:
+ import matplotlib.dates as mdates
+
+ time = pd.date_range("2000-01-01", "2000-01-10")
+ a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)])
+ self.plotfunc(a, x="t")
+
+ ax = plt.gca()
+
+ assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator)
+ assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter)
+
def test_primitive_artist_returned(self) -> None:
artist = self.plotmethod()
assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection)
@@ -3324,3 +3372,16 @@ def test_plot1d_default_rcparams() -> None:
np.testing.assert_allclose(
ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("k")
)
+
+
+@requires_matplotlib
+def test_plot1d_filtered_nulls() -> None:
+ ds = xr.tutorial.scatter_example_dataset(seed=42)
+ y = ds.y.where(ds.y > 0.2)
+ expected = y.notnull().sum().item()
+
+ with figure_context():
+ pc = y.plot.scatter()
+ actual = pc.get_offsets().shape[0]
+
+ assert expected == actual
diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py
index 1af255d30bb..b518c973d3a 100644
--- a/xarray/tests/test_plugins.py
+++ b/xarray/tests/test_plugins.py
@@ -218,28 +218,29 @@ def test_lazy_import() -> None:
When importing xarray these should not be imported as well.
Only when running code for the first time that requires them.
"""
- blacklisted = [
+ deny_list = [
+ "cubed",
+ "cupy",
+ # "dask", # TODO: backends.locks is not lazy yet :(
+ "dask.array",
+ "dask.distributed",
+ "flox",
"h5netcdf",
+ "matplotlib",
+ "nc_time_axis",
"netCDF4",
- "pydap",
"Nio",
+ "numbagg",
+ "pint",
+ "pydap",
"scipy",
- "zarr",
- "matplotlib",
- "nc_time_axis",
- "flox",
- # "dask", # TODO: backends.locks is not lazy yet :(
- "dask.array",
- "dask.distributed",
"sparse",
- "cupy",
- "pint",
- "cubed",
+ "zarr",
]
# ensure that none of the above modules has been imported before
modules_backup = {}
for pkg in list(sys.modules.keys()):
- for mod in blacklisted + ["xarray"]:
+ for mod in deny_list + ["xarray"]:
if pkg.startswith(mod):
modules_backup[pkg] = sys.modules[pkg]
del sys.modules[pkg]
@@ -255,7 +256,7 @@ def test_lazy_import() -> None:
# lazy loaded are loaded when importing xarray
is_imported = set()
for pkg in sys.modules:
- for mod in blacklisted:
+ for mod in deny_list:
if pkg.startswith(mod):
is_imported.add(mod)
break
diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py
index cb7b723a208..7cb2cd70d29 100644
--- a/xarray/tests/test_rolling.py
+++ b/xarray/tests/test_rolling.py
@@ -10,7 +10,6 @@
from xarray import DataArray, Dataset, set_options
from xarray.tests import (
assert_allclose,
- assert_array_equal,
assert_equal,
assert_identical,
has_dask,
@@ -24,6 +23,44 @@
]
+@pytest.fixture(params=["numbagg", "bottleneck"])
+def compute_backend(request):
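+    # Toggle xr.set_options so that each test requesting this fixture runs once
+    # with numbagg enabled and once with bottleneck enabled.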
+ if request.param == "bottleneck":
+ options = dict(use_bottleneck=True, use_numbagg=False)
+ elif request.param == "numbagg":
+ options = dict(use_bottleneck=False, use_numbagg=True)
+ else:
+ raise ValueError
+
+ with xr.set_options(**options):
+ yield request.param
+
+
+@pytest.mark.parametrize("func", ["mean", "sum"])
+@pytest.mark.parametrize("min_periods", [1, 10])
+def test_cumulative(d, func, min_periods) -> None:
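+    # ``cumulative`` is expected to match ``rolling`` with a window spanning the
+    # full size of the given dimension(s).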
+ # One dim
+ result = getattr(d.cumulative("z", min_periods=min_periods), func)()
+ expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)()
+ assert_identical(result, expected)
+
+ # Multiple dim
+ result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)()
+ expected = getattr(
+ d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods),
+ func,
+ )()
+ assert_identical(result, expected)
+
+
+def test_cumulative_vs_cum(d) -> None:
+ result = d.cumulative("z").sum()
+ expected = d.cumsum("z")
+ # cumsum drops the coord of the dimension; cumulative doesn't
+ expected = expected.assign_coords(z=result["z"])
+ assert_identical(result, expected)
+
+
class TestDataArrayRolling:
@pytest.mark.parametrize("da", (1, 2), indirect=True)
@pytest.mark.parametrize("center", [True, False])
@@ -87,9 +124,10 @@ def test_rolling_properties(self, da) -> None:
@pytest.mark.parametrize("center", (True, False, None))
@pytest.mark.parametrize("min_periods", (1, None))
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
- def test_rolling_wrapped_bottleneck(self, da, name, center, min_periods) -> None:
+ def test_rolling_wrapped_bottleneck(
+ self, da, name, center, min_periods, compute_backend
+ ) -> None:
bn = pytest.importorskip("bottleneck", minversion="1.1")
-
# Test all bottleneck functions
rolling_obj = da.rolling(time=7, min_periods=min_periods)
@@ -98,7 +136,9 @@ def test_rolling_wrapped_bottleneck(self, da, name, center, min_periods) -> None
expected = getattr(bn, func_name)(
da.values, window=7, axis=1, min_count=min_periods
)
- assert_array_equal(actual.values, expected)
+
+ # Using assert_allclose because we get tiny (1e-17) differences in numbagg.
+ np.testing.assert_allclose(actual.values, expected)
with pytest.warns(DeprecationWarning, match="Reductions are applied"):
getattr(rolling_obj, name)(dim="time")
@@ -106,7 +146,8 @@ def test_rolling_wrapped_bottleneck(self, da, name, center, min_periods) -> None
# Test center
rolling_obj = da.rolling(time=7, center=center)
actual = getattr(rolling_obj, name)()["time"]
- assert_equal(actual, da["time"])
+ # Using assert_allclose because we get tiny (1e-17) differences in numbagg.
+ assert_allclose(actual, da["time"])
@requires_dask
@pytest.mark.parametrize("name", ("mean", "count"))
@@ -153,7 +194,9 @@ def test_rolling_wrapped_dask_nochunk(self, center) -> None:
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
- def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
+ def test_rolling_pandas_compat(
+ self, center, window, min_periods, compute_backend
+ ) -> None:
s = pd.Series(np.arange(10))
da = DataArray.from_series(s)
@@ -203,7 +246,9 @@ def test_rolling_construct(self, center: bool, window: int) -> None:
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("name", ("sum", "mean", "std", "max"))
- def test_rolling_reduce(self, da, center, min_periods, window, name) -> None:
+ def test_rolling_reduce(
+ self, da, center, min_periods, window, name, compute_backend
+ ) -> None:
if min_periods is not None and window < min_periods:
min_periods = window
@@ -217,13 +262,15 @@ def test_rolling_reduce(self, da, center, min_periods, window, name) -> None:
actual = rolling_obj.reduce(getattr(np, "nan%s" % name))
expected = getattr(rolling_obj, name)()
assert_allclose(actual, expected)
- assert actual.dims == expected.dims
+ assert actual.sizes == expected.sizes
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("name", ("sum", "max"))
- def test_rolling_reduce_nonnumeric(self, center, min_periods, window, name) -> None:
+ def test_rolling_reduce_nonnumeric(
+ self, center, min_periods, window, name, compute_backend
+ ) -> None:
da = DataArray(
[0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time"
).isnull()
@@ -237,9 +284,9 @@ def test_rolling_reduce_nonnumeric(self, center, min_periods, window, name) -> N
actual = rolling_obj.reduce(getattr(np, "nan%s" % name))
expected = getattr(rolling_obj, name)()
assert_allclose(actual, expected)
- assert actual.dims == expected.dims
+ assert actual.sizes == expected.sizes
- def test_rolling_count_correct(self) -> None:
+ def test_rolling_count_correct(self, compute_backend) -> None:
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")
kwargs: list[dict[str, Any]] = [
@@ -279,7 +326,9 @@ def test_rolling_count_correct(self) -> None:
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1))
@pytest.mark.parametrize("name", ("sum", "mean", "max"))
- def test_ndrolling_reduce(self, da, center, min_periods, name) -> None:
+ def test_ndrolling_reduce(
+ self, da, center, min_periods, name, compute_backend
+ ) -> None:
rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods)
actual = getattr(rolling_obj, name)()
@@ -291,7 +340,7 @@ def test_ndrolling_reduce(self, da, center, min_periods, name) -> None:
)()
assert_allclose(actual, expected)
- assert actual.dims == expected.dims
+ assert actual.sizes == expected.sizes
if name in ["mean"]:
# test our reimplementation of nanmean using np.nanmean
@@ -560,7 +609,7 @@ def test_rolling_properties(self, ds) -> None:
@pytest.mark.parametrize("key", ("z1", "z2"))
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_wrapped_bottleneck(
- self, ds, name, center, min_periods, key
+ self, ds, name, center, min_periods, key, compute_backend
) -> None:
bn = pytest.importorskip("bottleneck", minversion="1.1")
@@ -577,12 +626,12 @@ def test_rolling_wrapped_bottleneck(
)
else:
raise ValueError
- assert_array_equal(actual[key].values, expected)
+ np.testing.assert_allclose(actual[key].values, expected)
# Test center
rolling_obj = ds.rolling(time=7, center=center)
actual = getattr(rolling_obj, name)()["time"]
- assert_equal(actual, ds["time"])
+ assert_allclose(actual, ds["time"])
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
@@ -700,7 +749,7 @@ def test_rolling_reduce(self, ds, center, min_periods, window, name) -> None:
actual = rolling_obj.reduce(getattr(np, "nan%s" % name))
expected = getattr(rolling_obj, name)()
assert_allclose(actual, expected)
- assert ds.dims == actual.dims
+ assert ds.sizes == actual.sizes
# make sure the order of data_var are not changed.
assert list(ds.data_vars.keys()) == list(actual.data_vars.keys())
@@ -727,7 +776,7 @@ def test_ndrolling_reduce(self, ds, center, min_periods, name, dask) -> None:
name,
)()
assert_allclose(actual, expected)
- assert actual.dims == expected.dims
+ assert actual.sizes == expected.sizes
# Do it in the opposite order
expected = getattr(
@@ -738,7 +787,7 @@ def test_ndrolling_reduce(self, ds, center, min_periods, name, dask) -> None:
)()
assert_allclose(actual, expected)
- assert actual.dims == expected.dims
+ assert actual.sizes == expected.sizes
@pytest.mark.parametrize("center", (True, False, (True, False)))
@pytest.mark.parametrize("fill_value", (np.nan, 0.0))
diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py
new file mode 100644
index 00000000000..44f0d56cde8
--- /dev/null
+++ b/xarray/tests/test_strategies.py
@@ -0,0 +1,271 @@
+import numpy as np
+import numpy.testing as npt
+import pytest
+
+pytest.importorskip("hypothesis")
+# isort: split
+
+import hypothesis.extra.numpy as npst
+import hypothesis.strategies as st
+from hypothesis import given
+from hypothesis.extra.array_api import make_strategies_namespace
+
+from xarray.core.variable import Variable
+from xarray.testing.strategies import (
+ attrs,
+ dimension_names,
+ dimension_sizes,
+ supported_dtypes,
+ unique_subset_of,
+ variables,
+)
+from xarray.tests import requires_numpy_array_api
+
+ALLOWED_ATTRS_VALUES_TYPES = (int, bool, str, np.ndarray)
+
+
+class TestDimensionNamesStrategy:
+ @given(dimension_names())
+ def test_types(self, dims):
+ assert isinstance(dims, list)
+ for d in dims:
+ assert isinstance(d, str)
+
+ @given(dimension_names())
+ def test_unique(self, dims):
+ assert len(set(dims)) == len(dims)
+
+ @given(st.data(), st.tuples(st.integers(0, 10), st.integers(0, 10)).map(sorted))
+ def test_number_of_dims(self, data, ndims):
+ min_dims, max_dims = ndims
+ dim_names = data.draw(dimension_names(min_dims=min_dims, max_dims=max_dims))
+ assert isinstance(dim_names, list)
+ assert min_dims <= len(dim_names) <= max_dims
+
+
+class TestDimensionSizesStrategy:
+ @given(dimension_sizes())
+ def test_types(self, dims):
+ assert isinstance(dims, dict)
+ for d, n in dims.items():
+ assert isinstance(d, str)
+ assert len(d) >= 1
+
+ assert isinstance(n, int)
+ assert n >= 0
+
+ @given(st.data(), st.tuples(st.integers(0, 10), st.integers(0, 10)).map(sorted))
+ def test_number_of_dims(self, data, ndims):
+ min_dims, max_dims = ndims
+ dim_sizes = data.draw(dimension_sizes(min_dims=min_dims, max_dims=max_dims))
+ assert isinstance(dim_sizes, dict)
+ assert min_dims <= len(dim_sizes) <= max_dims
+
+ @given(st.data())
+ def test_restrict_names(self, data):
+ capitalized_names = st.text(st.characters(), min_size=1).map(str.upper)
+ dim_sizes = data.draw(dimension_sizes(dim_names=capitalized_names))
+ for dim in dim_sizes.keys():
+ assert dim.upper() == dim
+
+
+def check_dict_values(dictionary: dict, allowed_attrs_values_types) -> bool:
+    """Helper function to check that all values in a possibly nested dict are instances of one of a set of allowed types."""
+ for key, value in dictionary.items():
+ if isinstance(value, allowed_attrs_values_types) or value is None:
+ continue
+ elif isinstance(value, dict):
+ # If the value is a dictionary, recursively check it
+ if not check_dict_values(value, allowed_attrs_values_types):
+ return False
+ else:
+            # If the value is neither one of the allowed types nor a dict, it's not valid
+ return False
+ return True
+
+
+class TestAttrsStrategy:
+ @given(attrs())
+ def test_type(self, attrs):
+ assert isinstance(attrs, dict)
+ check_dict_values(attrs, ALLOWED_ATTRS_VALUES_TYPES)
+
+
+class TestVariablesStrategy:
+ @given(variables())
+ def test_given_nothing(self, var):
+ assert isinstance(var, Variable)
+
+ @given(st.data())
+ def test_given_incorrect_types(self, data):
+ with pytest.raises(TypeError, match="dims must be provided as a"):
+ data.draw(variables(dims=["x", "y"])) # type: ignore[arg-type]
+
+ with pytest.raises(TypeError, match="dtype must be provided as a"):
+ data.draw(variables(dtype=np.dtype("int32"))) # type: ignore[arg-type]
+
+ with pytest.raises(TypeError, match="attrs must be provided as a"):
+ data.draw(variables(attrs=dict())) # type: ignore[arg-type]
+
+ with pytest.raises(TypeError, match="Callable"):
+ data.draw(variables(array_strategy_fn=np.array([0]))) # type: ignore[arg-type]
+
+ @given(st.data(), dimension_names())
+ def test_given_fixed_dim_names(self, data, fixed_dim_names):
+ var = data.draw(variables(dims=st.just(fixed_dim_names)))
+
+ assert list(var.dims) == fixed_dim_names
+
+ @given(st.data(), dimension_sizes())
+ def test_given_fixed_dim_sizes(self, data, dim_sizes):
+ var = data.draw(variables(dims=st.just(dim_sizes)))
+
+ assert var.dims == tuple(dim_sizes.keys())
+ assert var.shape == tuple(dim_sizes.values())
+
+ @given(st.data(), supported_dtypes())
+ def test_given_fixed_dtype(self, data, dtype):
+ var = data.draw(variables(dtype=st.just(dtype)))
+
+ assert var.dtype == dtype
+
+ @given(st.data(), npst.arrays(shape=npst.array_shapes(), dtype=supported_dtypes()))
+ def test_given_fixed_data_dims_and_dtype(self, data, arr):
+ def fixed_array_strategy_fn(*, shape=None, dtype=None):
+ """The fact this ignores shape and dtype is only okay because compatible shape & dtype will be passed separately."""
+ return st.just(arr)
+
+ dim_names = data.draw(dimension_names(min_dims=arr.ndim, max_dims=arr.ndim))
+ dim_sizes = {name: size for name, size in zip(dim_names, arr.shape)}
+
+ var = data.draw(
+ variables(
+ array_strategy_fn=fixed_array_strategy_fn,
+ dims=st.just(dim_sizes),
+ dtype=st.just(arr.dtype),
+ )
+ )
+
+ npt.assert_equal(var.data, arr)
+ assert var.dtype == arr.dtype
+
+ @given(st.data(), st.integers(0, 3))
+ def test_given_array_strat_arbitrary_size_and_arbitrary_data(self, data, ndims):
+ dim_names = data.draw(dimension_names(min_dims=ndims, max_dims=ndims))
+
+ def array_strategy_fn(*, shape=None, dtype=None):
+ return npst.arrays(shape=shape, dtype=dtype)
+
+ var = data.draw(
+ variables(
+ array_strategy_fn=array_strategy_fn,
+ dims=st.just(dim_names),
+ dtype=supported_dtypes(),
+ )
+ )
+
+ assert var.ndim == ndims
+
+ @given(st.data())
+ def test_catch_unruly_dtype_from_custom_array_strategy_fn(self, data):
+ def dodgy_array_strategy_fn(*, shape=None, dtype=None):
+ """Dodgy function which ignores the dtype it was passed"""
+ return npst.arrays(shape=shape, dtype=npst.floating_dtypes())
+
+ with pytest.raises(
+ ValueError, match="returned an array object with a different dtype"
+ ):
+ data.draw(
+ variables(
+ array_strategy_fn=dodgy_array_strategy_fn,
+ dtype=st.just(np.dtype("int32")),
+ )
+ )
+
+ @given(st.data())
+ def test_catch_unruly_shape_from_custom_array_strategy_fn(self, data):
+ def dodgy_array_strategy_fn(*, shape=None, dtype=None):
+ """Dodgy function which ignores the shape it was passed"""
+ return npst.arrays(shape=(3, 2), dtype=dtype)
+
+ with pytest.raises(
+ ValueError, match="returned an array object with a different shape"
+ ):
+ data.draw(
+ variables(
+ array_strategy_fn=dodgy_array_strategy_fn,
+ dims=st.just({"a": 2, "b": 1}),
+ dtype=supported_dtypes(),
+ )
+ )
+
+ @requires_numpy_array_api
+ @given(st.data())
+ def test_make_strategies_namespace(self, data):
+ """
+        Test that we don't trigger a hypothesis.InvalidArgument by generating a dtype that's not in the array API.
+
+ We still want to generate dtypes not in the array API by default, but this checks we don't accidentally override
+ the user's choice of dtypes with non-API-compliant ones.
+ """
+ from numpy import (
+ array_api as np_array_api, # requires numpy>=1.26.0, and we expect a UserWarning to be raised
+ )
+
+ np_array_api_st = make_strategies_namespace(np_array_api)
+
+ data.draw(
+ variables(
+ array_strategy_fn=np_array_api_st.arrays,
+ dtype=np_array_api_st.scalar_dtypes(),
+ )
+ )
+
+
+class TestUniqueSubsetOf:
+ @given(st.data())
+ def test_invalid(self, data):
+ with pytest.raises(TypeError, match="must be an Iterable or a Mapping"):
+ data.draw(unique_subset_of(0)) # type: ignore[call-overload]
+
+ with pytest.raises(ValueError, match="length-zero object"):
+ data.draw(unique_subset_of({}))
+
+ @given(st.data(), dimension_sizes(min_dims=1))
+ def test_mapping(self, data, dim_sizes):
+ subset_of_dim_sizes = data.draw(unique_subset_of(dim_sizes))
+
+ for dim, length in subset_of_dim_sizes.items():
+ assert dim in dim_sizes
+ assert dim_sizes[dim] == length
+
+ @given(st.data(), dimension_names(min_dims=1))
+ def test_iterable(self, data, dim_names):
+ subset_of_dim_names = data.draw(unique_subset_of(dim_names))
+
+ for dim in subset_of_dim_names:
+ assert dim in dim_names
+
+
+class TestReduction:
+ """
+ These tests are for checking that the examples given in the docs page on testing actually work.
+ """
+
+ @given(st.data(), variables(dims=dimension_names(min_dims=1)))
+ def test_mean(self, data, var):
+ """
+ Test that given a Variable of at least one dimension,
+ the mean of the Variable is always equal to the mean of the underlying array.
+ """
+
+ # specify arbitrary reduction along at least one dimension
+ reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1))
+
+        # create expected result (using nanmean because arrays with NaNs will be generated)
+ reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims)
+ expected = np.nanmean(var.data, axis=reduction_axes)
+
+ # assert property is always satisfied
+ result = var.mean(dim=reduction_dims).data
+ npt.assert_equal(expected, result)
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 8a73e435977..a2ae1e61cf2 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import warnings
-from abc import ABC, abstractmethod
+from abc import ABC
from copy import copy, deepcopy
from datetime import datetime, timedelta
from textwrap import dedent
@@ -46,6 +46,7 @@
requires_sparse,
source_ndarray,
)
+from xarray.tests.test_namedarray import NamedArraySubclassobjects
dask_array_type = array_type("dask")
@@ -63,34 +64,11 @@ def var():
return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5))
-class VariableSubclassobjects(ABC):
- @abstractmethod
- def cls(self, *args, **kwargs) -> Variable:
- raise NotImplementedError
-
- def test_properties(self):
- data = 0.5 * np.arange(10)
- v = self.cls(["time"], data, {"foo": "bar"})
- assert v.dims == ("time",)
- assert_array_equal(v.values, data)
- assert v.dtype == float
- assert v.shape == (10,)
- assert v.size == 10
- assert v.sizes == {"time": 10}
- assert v.nbytes == 80
- assert v.ndim == 1
- assert len(v) == 10
- assert v.attrs == {"foo": "bar"}
-
- def test_attrs(self):
- v = self.cls(["time"], 0.5 * np.arange(10))
- assert v.attrs == {}
- attrs = {"foo": "bar"}
- v.attrs = attrs
- assert v.attrs == attrs
- assert isinstance(v.attrs, dict)
- v.attrs["foo"] = "baz"
- assert v.attrs["foo"] == "baz"
+class VariableSubclassobjects(NamedArraySubclassobjects, ABC):
+ @pytest.fixture
+ def target(self, data):
+ data = 0.5 * np.arange(10).reshape(2, 5)
+ return Variable(["x", "y"], data)
def test_getitem_dict(self):
v = self.cls(["x"], np.random.randn(5))
@@ -368,7 +346,7 @@ def test_1d_math(self, dtype: np.typing.DTypeLike) -> None:
assert_array_equal(v >> 2, x >> 2)
# binary ops with numpy arrays
assert_array_equal((v * x).values, x**2)
- assert_array_equal((x * v).values, x**2) # type: ignore[attr-defined] # TODO: Fix mypy thinking numpy takes priority, GH7780
+ assert_array_equal((x * v).values, x**2)
assert_array_equal(v - y, v - 1)
assert_array_equal(y - v, 1 - v)
if dtype is int:
@@ -1065,9 +1043,8 @@ def cls(self, *args, **kwargs) -> Variable:
def setup(self):
self.d = np.random.random((10, 3)).astype(np.float64)
- def test_data_and_values(self):
+ def test_values(self):
v = Variable(["time", "x"], self.d)
- assert_array_equal(v.data, self.d)
assert_array_equal(v.values, self.d)
assert source_ndarray(v.values) is self.d
with pytest.raises(ValueError):
@@ -1076,9 +1053,6 @@ def test_data_and_values(self):
d2 = np.random.random((10, 3))
v.values = d2
assert source_ndarray(v.values) is d2
- d3 = np.random.random((10, 3))
- v.data = d3
- assert source_ndarray(v.data) is d3
def test_numpy_same_methods(self):
v = Variable([], np.float32(0.0))
@@ -1731,6 +1705,7 @@ def test_broadcasting_math(self):
v * w[0], Variable(["a", "b", "c", "d"], np.einsum("ab,cd->abcd", x, y[0]))
)
+ @pytest.mark.filterwarnings("ignore:Duplicate dimension names")
def test_broadcasting_failures(self):
a = Variable(["x"], np.arange(10))
b = Variable(["x"], np.arange(5))
@@ -1878,9 +1853,20 @@ def test_quantile_out_of_bounds(self, q):
@requires_dask
@requires_bottleneck
- def test_rank_dask_raises(self):
- v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2)
- with pytest.raises(TypeError, match=r"arrays stored as dask"):
+ def test_rank_dask(self):
+ # Instead of a single test here, we could parameterize the other tests for both
+        # numpy and dask arrays. But this is sufficient.
+ v = Variable(
+ ["x", "y"], [[30.0, 1.0, np.nan, 20.0, 4.0], [30.0, 1.0, np.nan, 20.0, 4.0]]
+ ).chunk(x=1)
+ expected = Variable(
+ ["x", "y"], [[4.0, 1.0, np.nan, 3.0, 2.0], [4.0, 1.0, np.nan, 3.0, 2.0]]
+ )
+ assert_equal(v.rank("y").compute(), expected)
+
+ with pytest.raises(
+ ValueError, match=r" with dask='parallelized' consists of multiple chunks"
+ ):
v.rank("x")
def test_rank_use_bottleneck(self):
@@ -1912,7 +1898,8 @@ def test_rank(self):
v_expect = Variable(["x"], [0.75, 0.25, np.nan, 0.5, 1.0])
assert_equal(v.rank("x", pct=True), v_expect)
# invalid dim
- with pytest.raises(ValueError, match=r"not found"):
+ with pytest.raises(ValueError):
+ # apply_ufunc error message isn't great here — `ValueError: tuple.index(x): x not in tuple`
v.rank("y")
def test_big_endian_reduce(self):
diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py
index 7b4cf901aa1..c620e45574e 100644
--- a/xarray/util/deprecation_helpers.py
+++ b/xarray/util/deprecation_helpers.py
@@ -36,6 +36,8 @@
from functools import wraps
from typing import Callable, TypeVar
+from xarray.core.utils import emit_user_level_warning
+
T = TypeVar("T", bound=Callable)
POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
@@ -115,3 +117,28 @@ def inner(*args, **kwargs):
return inner
return _decorator
+
+
+def deprecate_dims(func: T) -> T:
+ """
+ For functions that previously took `dims` as a kwarg, and have now transitioned to
+ `dim`. This decorator will issue a warning if `dims` is passed while forwarding it
+ to `dim`.
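+
+    A minimal usage sketch (``total`` is a hypothetical function, shown only for
+    illustration)::
+
+        @deprecate_dims
+        def total(da, *, dim=None):
+            return da.sum(dim=dim)
+
+        total(da, dims="x")  # warns, then forwards as total(da, dim="x")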
+ """
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if "dims" in kwargs:
+ emit_user_level_warning(
+ "The `dims` argument has been renamed to `dim`, and will be removed "
+ "in the future. This renaming is taking place throughout xarray over the "
+ "next few releases.",
+ # Upgrade to `DeprecationWarning` in the future, when the renaming is complete.
+ PendingDeprecationWarning,
+ )
+ kwargs["dim"] = kwargs.pop("dims")
+ return func(*args, **kwargs)
+
+ # We're quite confident we're just returning `T` from this function, so it's fine to ignore typing
+ # within the function.
+ return wrapper # type: ignore