From db36c5c0cdee2f5313a81fdeca8a8ae5491d1c8f Mon Sep 17 00:00:00 2001 From: hazbottles Date: Fri, 3 Jan 2020 12:16:44 +0000 Subject: [PATCH 01/23] add multiindex level name checking to .rename() (#3658) * add multiindex level name checking to .rename() * update whats-new.rst --- doc/whats-new.rst | 2 ++ xarray/core/dataset.py | 9 ++++++++- xarray/tests/test_dataset.py | 8 ++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 00d1c50780e..70853dbb730 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -62,6 +62,8 @@ Bug fixes By `Tom Augspurger `_. - Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser `_. +- :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename` now check for conflicts with + MultiIndex level names. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6be06fed117..032c66d3778 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -89,7 +89,13 @@ is_scalar, maybe_wrap_array, ) -from .variable import IndexVariable, Variable, as_variable, broadcast_variables +from .variable import ( + IndexVariable, + Variable, + as_variable, + broadcast_variables, + assert_unique_multiindex_level_names, +) if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore @@ -2780,6 +2786,7 @@ def rename( variables, coord_names, dims, indexes = self._rename_all( name_dict=name_dict, dims_dict=name_dict ) + assert_unique_multiindex_level_names(variables) return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename_dims( diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7db1911621b..edda4cb2a58 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2461,6 +2461,14 @@ def test_rename_vars(self): with pytest.raises(ValueError): original.rename_vars(names_dict_bad) + def test_rename_multiindex(self): + mindex = pd.MultiIndex.from_tuples( + [([1, 2]), ([3, 4])], names=["level0", "level1"] + ) + data = Dataset({}, {"x": mindex}) + with raises_regex(ValueError, "conflicting MultiIndex"): + data.rename({"x": "level0"}) + @requires_cftime def test_rename_does_not_change_CFTimeIndex_type(self): # make sure CFTimeIndex is not converted to DatetimeIndex #3522 From 8fc9ecedc87b5d878363b233e260c87fd632fa0f Mon Sep 17 00:00:00 2001 From: Riley Brady Date: Wed, 8 Jan 2020 10:50:03 -0700 Subject: [PATCH 02/23] Add map_blocks example to docs. (#3667) --- xarray/core/parallel.py | 42 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index dd6c67338d8..e4fb5803191 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -154,6 +154,48 @@ def map_blocks( -------- dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, xarray.DataArray.map_blocks + + Examples + -------- + + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type='time.month'): + ... # Necessary workaround to xarray's check with zero dimensions + ... # https://github.com/pydata/xarray/issues/3575 + ... if sum(da.shape) == 0: + ... return da + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim='time') + ... 
return gb - clim + >>> time = xr.cftime_range('1990-01', '1992-01', freq='M') + >>> np.random.seed(123) + >>> array = xr.DataArray(np.random.rand(len(time)), + ... dims="time", coords=[time]).chunk() + >>> xr.map_blocks(calculate_anomaly, array).compute() + + array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, + 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, + -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , + 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, + 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> xr.map_blocks(calculate_anomaly, array, kwargs={'groupby_type': 'time.year'}) + + array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 , + -0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425, + -0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273, + 0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 , + 0.14482397, 0.35985481, 0.23487834, 0.12144652]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ def _wrapper(func, obj, to_array, args, kwargs): From 080caf4246fe2f4d6aa0c5dcb65a99b376fa669b Mon Sep 17 00:00:00 2001 From: keewis Date: Wed, 8 Jan 2020 19:27:29 +0100 Subject: [PATCH 03/23] Support swap_dims to dimension names that are not existing variables (#3636) * test that swapping to a non-existing name works * don't try to get a variable if the variable does not exist * don't add dimensions to coord_names if they are not existing variables * add whats-new.rst entry * update the documentation --- doc/whats-new.rst | 3 +++ xarray/core/dataarray.py | 10 ++++++++-- xarray/core/dataset.py | 17 +++++++++++++---- xarray/tests/test_dataarray.py | 5 +++++ xarray/tests/test_dataset.py | 6 ++++++ 5 files changed, 35 insertions(+), 6 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 70853dbb730..351424fbb9f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,9 @@ New Features - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) By `Deepak Cherian `_ +- :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims` + now allow swapping to dimension names that don't exist yet. (:pull:`3636`) + By `Justus Magin `_. - Extend :py:class:`core.accessor_dt.DatetimeAccessor` properties and support `.dt` accessor for timedelta via :py:class:`core.accessor_dt.TimedeltaAccessor` (:pull:`3612`) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 31aa4da57b2..cbd8d243385 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1480,8 +1480,7 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": ---------- dims_dict : dict-like Dictionary whose keys are current dimension names and whose values - are new names. Each value must already be a coordinate on this - array. + are new names. 
Returns ------- @@ -1504,6 +1503,13 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": Coordinates: x (y) >> arr.swap_dims({"x": "z"}) + + array([0, 1]) + Coordinates: + x (z) >> ds.swap_dims({"x": "z"}) + + Dimensions: (z: 2) + Coordinates: + x (z) Date: Thu, 9 Jan 2020 02:46:45 +0100 Subject: [PATCH 04/23] raise an error when renaming dimensions to existing names (#3645) * check we raise an error if the name already exists * raise if the new name already exists and point to swap_dims * update the documentation * whats-new.rst * fix the docstring of rename_dims Co-Authored-By: Deepak Cherian --- doc/whats-new.rst | 3 +++ xarray/core/dataset.py | 10 ++++++++-- xarray/tests/test_dataset.py | 3 +++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 351424fbb9f..5a9f2497ed6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -65,6 +65,9 @@ Bug fixes By `Tom Augspurger `_. - Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser `_. +- Raise an error when trying to use :py:meth:`Dataset.rename_dims` to + rename to an existing name (:issue:`3438`, :pull:`3645`) + By `Justus Magin `_. - :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename` now check for conflicts with MultiIndex level names. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 607350c8101..ac0a923db78 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2798,7 +2798,8 @@ def rename_dims( ---------- dims_dict : dict-like, optional Dictionary whose keys are current dimension names and - whose values are the desired names. + whose values are the desired names. The desired names must + not be the name of an existing dimension or Variable in the Dataset. **dims, optional Keyword form of ``dims_dict``. One of dims_dict or dims must be provided. @@ -2816,12 +2817,17 @@ def rename_dims( DataArray.rename """ dims_dict = either_dict_or_kwargs(dims_dict, dims, "rename_dims") - for k in dims_dict: + for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( "cannot rename %r because it is not a " "dimension in this dataset" % k ) + if v in self.dims or v in self: + raise ValueError( + f"Cannot rename {k} to {v} because {v} already exists. " + "Try using swap_dims instead." + ) variables, coord_names, sizes, indexes = self._rename_all( name_dict={}, dims_dict=dims_dict diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 48d8c25b810..2220abbef31 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2444,6 +2444,9 @@ def test_rename_dims(self): with pytest.raises(ValueError): original.rename_dims(dims_dict_bad) + with pytest.raises(ValueError): + original.rename_dims({"x": "z"}) + def test_rename_vars(self): original = Dataset({"x": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42}) expected = Dataset( From 24f9292114894621d8eb7a4eade6347538ce0d23 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 10 Jan 2020 16:10:56 +0000 Subject: [PATCH 05/23] Make dask names change when chunking Variables by different amounts. (#3584) * Make dask tokens change when chunking Variables by different amounts. When rechunking by the current chunk size, the dask token should not change. Add a __dask_tokenize__ method for ReprObject so that this behaviour is present when DataArrays are converted to temporary Datasets and back. 
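
A minimal sketch of the behaviour this enforces (assumes dask is
installed; the array and its name are illustrative, not part of the
patch):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.zeros((4, 4)), dims=("x", "y"))

    # rechunking by a different amount must produce a new dask name ...
    assert da.chunk(2).data.name != da.chunk(4).data.name
    # ... while rechunking by the same amount keeps the name stable
    assert da.chunk(2).data.name == da.chunk(2).data.name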
Co-Authored-By: crusaderky
Co-authored-by: crusaderky
---
 doc/whats-new.rst              |  4 ++++
 xarray/core/dataset.py         |  5 ++++-
 xarray/core/utils.py           |  7 ++++++-
 xarray/tests/test_dask.py      | 18 +++++++++---------
 xarray/tests/test_dataarray.py |  7 +++++++
 xarray/tests/test_dataset.py   | 16 ++++++++++++++++
 6 files changed, 46 insertions(+), 11 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 5a9f2497ed6..e1c4e3dd9ac 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -47,6 +47,7 @@ New Features

 Bug fixes
 ~~~~~~~~~
+
 - Fix :py:meth:`xarray.combine_by_coords` to allow for combining incomplete
   hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger `_.
@@ -91,6 +92,9 @@ Documentation

 Internal Changes
 ~~~~~~~~~~~~~~~~

+- Make sure dask names change when rechunking by different chunk sizes. Conversely, make sure they
+  stay the same when rechunking by the same chunk size. (:issue:`3350`)
+  By `Deepak Cherian `_.
 - 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`,
   :py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing
   by int, slice, list of int, scalar ndarray, or 1-dimensional ndarray.

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index ac0a923db78..129f0c0f7a7 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -1754,7 +1754,10 @@ def maybe_chunk(name, var, chunks):
         if not chunks:
             chunks = None
         if var.ndim > 0:
-            token2 = tokenize(name, token if token else var._data)
+            # when rechunking by different amounts, make sure dask names change
+            # by providing chunks as an input to tokenize.
+            # subtle bugs result otherwise. see GH3350
+            token2 = tokenize(name, token if token else var._data, chunks)
             name2 = f"{name_prefix}{name}-{token2}"
             return var.chunk(chunks, name=name2, lock=lock)
         else:
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index 6681375c18e..e335365d5ca 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -547,7 +547,12 @@ def __eq__(self, other) -> bool:
         return False

     def __hash__(self) -> int:
-        return hash((ReprObject, self._value))
+        return hash((type(self), self._value))
+
+    def __dask_tokenize__(self):
+        from dask.base import normalize_token
+
+        return normalize_token((type(self), self._value))


 @contextlib.contextmanager
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index d0e2654eed3..cc554850839 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1083,7 +1083,7 @@ def func(obj):
         actual = xr.map_blocks(func, obj)
         expected = func(obj)
     assert_chunks_equal(expected.chunk(), actual)
-    xr.testing.assert_identical(actual.compute(), expected.compute())
+    assert_identical(actual, expected)


 @pytest.mark.parametrize("obj", [make_da(), make_ds()])
@@ -1092,7 +1092,7 @@ def test_map_blocks_convert_args_to_list(obj):
     with raise_if_dask_computes():
         actual = xr.map_blocks(operator.add, obj, [10])
     assert_chunks_equal(expected.chunk(), actual)
-    xr.testing.assert_identical(actual.compute(), expected.compute())
+    assert_identical(actual, expected)


 @pytest.mark.parametrize("obj", [make_da(), make_ds()])
@@ -1107,7 +1107,7 @@ def add_attrs(obj):

     with raise_if_dask_computes():
         actual = xr.map_blocks(add_attrs, obj)

-    xr.testing.assert_identical(actual.compute(), expected.compute())
+    assert_identical(actual, expected)


 def test_map_blocks_change_name(map_da):
@@ -1120,7 +1120,7 @@ def change_name(obj):

     with raise_if_dask_computes():
         actual = xr.map_blocks(change_name, map_da)

-    xr.testing.assert_identical(actual.compute(), expected.compute())
+    
assert_identical(actual, expected) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1129,7 +1129,7 @@ def test_map_blocks_kwargs(obj): with raise_if_dask_computes(): actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) def test_map_blocks_to_array(map_ds): @@ -1137,7 +1137,7 @@ def test_map_blocks_to_array(map_ds): actual = xr.map_blocks(lambda x: x.to_array(), map_ds) # to_array does not preserve name, so cannot use assert_identical - assert_equal(actual.compute(), map_ds.to_array().compute()) + assert_equal(actual, map_ds.to_array()) @pytest.mark.parametrize( @@ -1156,7 +1156,7 @@ def test_map_blocks_da_transformations(func, map_da): with raise_if_dask_computes(): actual = xr.map_blocks(func, map_da) - assert_identical(actual.compute(), func(map_da).compute()) + assert_identical(actual, func(map_da)) @pytest.mark.parametrize( @@ -1175,7 +1175,7 @@ def test_map_blocks_ds_transformations(func, map_ds): with raise_if_dask_computes(): actual = xr.map_blocks(func, map_ds) - assert_identical(actual.compute(), func(map_ds).compute()) + assert_identical(actual, func(map_ds)) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1188,7 +1188,7 @@ def func(obj): expected = xr.map_blocks(func, obj) actual = obj.map_blocks(func) - assert_identical(expected.compute(), actual.compute()) + assert_identical(expected, actual) def test_map_blocks_hlg_layers(): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 4189c3b504a..786eb5007a6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -752,12 +752,19 @@ def test_chunk(self): blocked = unblocked.chunk() assert blocked.chunks == ((3,), (4,)) + first_dask_name = blocked.data.name blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) assert blocked.chunks == ((2, 1), (2, 2)) + assert blocked.data.name != first_dask_name blocked = unblocked.chunk(chunks=(3, 3)) assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name + + # name doesn't change when rechunking by same amount + # this fails if ReprObject doesn't have __dask_tokenize__ defined + assert unblocked.chunk(2).data.name == unblocked.chunk(2).data.name assert blocked.load().chunks is None diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2220abbef31..c953f5d22e9 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -936,19 +936,35 @@ def test_chunk(self): expected_chunks = {"dim1": (8,), "dim2": (9,), "dim3": (10,)} assert reblocked.chunks == expected_chunks + def get_dask_names(ds): + return {k: v.data.name for k, v in ds.items()} + + orig_dask_names = get_dask_names(reblocked) + reblocked = data.chunk({"time": 5, "dim1": 5, "dim2": 5, "dim3": 5}) # time is not a dim in any of the data_vars, so it # doesn't get chunked expected_chunks = {"dim1": (5, 3), "dim2": (5, 4), "dim3": (5, 5)} assert reblocked.chunks == expected_chunks + # make sure dask names change when rechunking by different amounts + # regression test for GH3350 + new_dask_names = get_dask_names(reblocked) + for k, v in new_dask_names.items(): + assert v != orig_dask_names[k] + reblocked = data.chunk(expected_chunks) assert reblocked.chunks == expected_chunks # reblock on already blocked data + orig_dask_names = get_dask_names(reblocked) reblocked = reblocked.chunk(expected_chunks) + new_dask_names = 
get_dask_names(reblocked)
         assert reblocked.chunks == expected_chunks
         assert_identical(reblocked, data)
+        # rechunking with same chunk sizes should not change names
+        for k, v in new_dask_names.items():
+            assert v == orig_dask_names[k]

         with raises_regex(ValueError, "some chunks"):
             data.chunk({"foo": 10})

From e8dbe6ee8cf3d94af0a35e92f7d77b08dc4b95df Mon Sep 17 00:00:00 2001
From: Riley Brady
Date: Fri, 10 Jan 2020 09:48:31 -0700
Subject: [PATCH 06/23] Add map_blocks example to whats-new (#3682)

---
 doc/whats-new.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index e1c4e3dd9ac..93df2b6569a 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -89,6 +89,8 @@ Documentation
 - Added examples for :py:meth:`DataArray.quantile`, :py:meth:`Dataset.quantile` and
   ``GroupBy.quantile``. (:pull:`3576`)
   By `Justus Magin `_.
+- Added example for :py:func:`~xarray.map_blocks`. (:pull:`3667`)
+  By `Riley X. Brady `_.

 Internal Changes
 ~~~~~~~~~~~~~~~~

From ff75081304eb2e2784dcb229cc48a532da557896 Mon Sep 17 00:00:00 2001
From: rpgoldman
Date: Fri, 10 Jan 2020 14:02:22 -0600
Subject: [PATCH 07/23] How do I add a new variable to dataset. (#3679)

* How do I add a new variable to dataset.

* Update doc/howdoi.rst

Improvement from dcherian

Co-Authored-By: Deepak Cherian

* Add cross-reference per suggestion.

* Fix cross-reference.

* Update doc/howdoi.rst

Suggestion from keewis

Co-Authored-By: keewis

Co-authored-by: Deepak Cherian
Co-authored-by: keewis
---
 doc/data-structures.rst | 2 ++
 doc/howdoi.rst          | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/doc/data-structures.rst b/doc/data-structures.rst
index 504d820a234..70e34adabed 100644
--- a/doc/data-structures.rst
+++ b/doc/data-structures.rst
@@ -353,6 +353,8 @@ setting) variables and attributes:
 This is particularly useful in an exploratory context, because you can
 tab-complete these variable names with tools like IPython.

+.. _dictionary_like_methods:
+
 Dictionary like methods
 ~~~~~~~~~~~~~~~~~~~~~~~

diff --git a/doc/howdoi.rst b/doc/howdoi.rst
index 80266bd3b84..84c0c786027 100644
--- a/doc/howdoi.rst
+++ b/doc/howdoi.rst
@@ -11,6 +11,8 @@ How do I ...

    * - How do I...
      - Solution
+   * - add a DataArray to my dataset as a new variable
+     - ``my_dataset[varname] = my_dataArray`` or :py:meth:`Dataset.assign` (see also :ref:`dictionary_like_methods`)
    * - add variables from other datasets to my dataset
      - :py:meth:`Dataset.merge`
    * - add a new dimension and/or coordinate
@@ -57,3 +59,4 @@ How do I ...
      - ``obj.dt.ceil``, ``obj.dt.floor``, ``obj.dt.round``. See :ref:`dt_accessor` for more.
    * - make a mask that is ``True`` where an object contains any of the values in a array
      - :py:meth:`Dataset.isin`, :py:meth:`DataArray.isin`
+

From 099c0901e29927935440692b1363ef08fe6d2e42 Mon Sep 17 00:00:00 2001
From: Julien Seguinot
Date: Sat, 11 Jan 2020 16:22:55 +0100
Subject: [PATCH 08/23] Add option to choose mfdataset attributes source.
 (#3498)

* Add 'master_file' kwarg in open_mfdataset, which can be a str or Path
  to a particular data file.
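
A minimal usage sketch of the final keyword spelling (``attrs_file``, as
it appears in the diff below; the file names here are hypothetical):

    import xarray as xr

    # global attributes for the combined dataset come from b.nc
    # instead of the (default) first file
    ds = xr.open_mfdataset(
        ["a.nc", "b.nc"],
        combine="nested",
        concat_dim="x",
        attrs_file="b.nc",
    )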
---
 doc/whats-new.rst             |  3 +++
 xarray/backends/api.py        | 19 +++++++++++++++---
 xarray/tests/test_backends.py | 36 +++++++++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 93df2b6569a..eacf8433c0a 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -37,6 +37,9 @@ New Features
 - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen`
   and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`)
   By `Deepak Cherian `_
+- Add `attrs_file` option in :py:func:`~xarray.open_mfdataset` to choose the
+  source file for global attributes in a multi-file dataset (:issue:`2382`,
+  :pull:`3498`) by `Julien Seguinot `_.
 - :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims`
   now allow swapping to dimension names that don't exist yet. (:pull:`3636`)
   By `Justus Magin `_.
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 23d09ba5e33..eea1fb9ddce 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -718,6 +718,7 @@ def open_mfdataset(
     autoclose=None,
     parallel=False,
     join="outer",
+    attrs_file=None,
     **kwargs,
 ):
     """Open multiple files as a single dataset.
@@ -729,8 +730,8 @@
     ``combine_by_coords`` and ``combine_nested``. By default the old (now deprecated)
     ``auto_combine`` will be used, please specify either ``combine='by_coords'`` or
     ``combine='nested'`` in future. Requires dask to be installed. See documentation for
-    details on dask [1]_. Attributes from the first dataset file are used for the
-    combined dataset.
+    details on dask [1]_. Global attributes from the ``attrs_file`` are used
+    for the combined dataset.

     Parameters
     ----------
@@ -827,6 +828,10 @@
         - 'override': if indexes are of same size, rewrite indexes to be
           those of the first object with that dimension. Indexes for the same
           dimension must have the same size in all objects.
+    attrs_file : str or pathlib.Path, optional
+        Path of the file used to read global attributes from.
+        By default global attributes are read from the first file provided,
+        with wildcard matches sorted by filename.
     **kwargs : optional
         Additional arguments passed on to :py:func:`xarray.open_dataset`.
@@ -961,7 +966,15 @@ def open_mfdataset( raise combined._file_obj = _MultiFileCloser(file_objs) - combined.attrs = datasets[0].attrs + + # read global attributes from the attrs_file or from the first dataset + if attrs_file is not None: + if isinstance(attrs_file, Path): + attrs_file = str(attrs_file) + combined.attrs = datasets[paths.index(attrs_file)].attrs + else: + combined.attrs = datasets[0].attrs + return combined diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a23527bd49a..4fccdf2dd6c 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2832,6 +2832,42 @@ def test_attrs_mfdataset(self): with raises_regex(AttributeError, "no attribute"): actual.test2 + def test_open_mfdataset_attrs_file(self): + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_files(2) as (tmp1, tmp2): + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2 + ) as actual: + # attributes are inherited from the master file + assert actual.attrs["test2"] == ds2.attrs["test2"] + # attributes from ds1 are not retained, e.g., + assert "test1" not in actual.attrs + + def test_open_mfdataset_attrs_file_path(self): + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_files(2) as (tmp1, tmp2): + tmp1 = Path(tmp1) + tmp2 = Path(tmp2) + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2 + ) as actual: + # attributes are inherited from the master file + assert actual.attrs["test2"] == ds2.attrs["test2"] + # attributes from ds1 are not retained, e.g., + assert "test1" not in actual.attrs + def test_open_mfdataset_auto_combine(self): original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) with create_tmp_file() as tmp1: From 40fab1b303d40c37d7307342cebbceb1ea8cc608 Mon Sep 17 00:00:00 2001 From: David Caron Date: Sat, 11 Jan 2020 18:15:44 -0500 Subject: [PATCH 09/23] fix docstring for combine_first: returns a Dataset (#3683) --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 129f0c0f7a7..c314eb41458 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4152,7 +4152,7 @@ def combine_first(self, other: "Dataset") -> "Dataset": Returns ------- - DataArray + Dataset """ out = ops.fillna(self, other, join="outer", dataset_join="outer") return out From 1689db493f10262555196f658c52e370aacb4a33 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Sun, 12 Jan 2020 13:04:01 +0000 Subject: [PATCH 10/23] ds.merge(da) bugfix (#3677) * Added mwe as test * Cast to Dataset * Updated what's new * black formatted * Use assert_identical --- doc/whats-new.rst | 2 ++ xarray/core/dataset.py | 1 + xarray/tests/test_merge.py | 7 +++++++ 3 files changed, 10 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eacf8433c0a..6eeb55c23cb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -74,6 +74,8 @@ Bug fixes By `Justus Magin `_. - :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename` now check for conflicts with MultiIndex level names. 
+- :py:meth:`Dataset.merge` no longer fails when passed a `DataArray` instead of a `Dataset` object. + By `Tom Nicholas `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c314eb41458..6ecd9b59f8e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3607,6 +3607,7 @@ def merge( If any variables conflict (see ``compat``). """ _check_inplace(inplace) + other = other.to_dataset() if isinstance(other, xr.DataArray) else other merge_result = dataset_merge_method( self, other, diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index c1e6c7a5ce8..6c8f3f65657 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -3,6 +3,7 @@ import xarray as xr from xarray.core import dtypes, merge +from xarray.testing import assert_identical from . import raises_regex from .test_dataset import create_test_data @@ -253,3 +254,9 @@ def test_merge_no_conflicts(self): with pytest.raises(xr.MergeError): ds3 = xr.Dataset({"a": ("y", [2, 3]), "y": [1, 2]}) ds1.merge(ds3, compat="no_conflicts") + + def test_merge_dataarray(self): + ds = xr.Dataset({"a": 0}) + da = xr.DataArray(data=1, name="b") + + assert_identical(ds.merge(da), xr.merge([ds, da])) From 59d3ba5e938bafb4a1981c1a56d42aa31041df0a Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Mon, 13 Jan 2020 11:31:37 -0500 Subject: [PATCH 11/23] Explicitly convert result of pd.to_datetime to a timezone-naive type (#3688) --- xarray/tests/test_coding_times.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index d012fb36c35..00c34940ce4 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -451,7 +451,7 @@ def test_cf_datetime_nan(num_dates, units, expected_list): warnings.filterwarnings("ignore", "All-NaN") actual = coding.times.decode_cf_datetime(num_dates, units) # use pandas because numpy will deprecate timezone-aware conversions - expected = pd.to_datetime(expected_list) + expected = pd.to_datetime(expected_list).to_numpy(dtype="datetime64[ns]") assert_array_equal(expected, actual) From 8a650a11d1f859a88cc91b8815c16597203892aa Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Mon, 13 Jan 2020 16:33:04 +0000 Subject: [PATCH 12/23] Fix mypy type checking tests failure in ds.merge (#3690) * Added DataArray as valid type input to ds.merge method * Fix import error by specifying type as string --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6ecd9b59f8e..2d15ff586e4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3550,7 +3550,7 @@ def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset": def merge( self, - other: "CoercibleMapping", + other: Union["CoercibleMapping", "DataArray"], inplace: bool = None, overwrite_vars: Union[Hashable, Iterable[Hashable]] = frozenset(), compat: str = "no_conflicts", From 40423457928245d216f973530904df1c93110f6c Mon Sep 17 00:00:00 2001 From: Emmanuel Roux <15956441+fesaille@users.noreply.github.com> Date: Mon, 13 Jan 2020 21:32:17 +0100 Subject: [PATCH 13/23] Typo on DataSet/DataArray.to_dict documentation (#3692) --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index cbd8d243385..15db6ed468e 100644 
--- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2368,7 +2368,7 @@ def to_dict(self, data: bool = True) -> dict: naming conventions. Converts all variables and attributes to native Python objects. - Useful for coverting to json. To avoid datetime incompatibility + Useful for converting to json. To avoid datetime incompatibility use decode_times=False kwarg in xarrray.open_dataset. Parameters diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2d15ff586e4..1aaad02b470 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4667,7 +4667,7 @@ def to_dict(self, data=True): conventions. Converts all variables and attributes to native Python objects - Useful for coverting to json. To avoid datetime incompatibility + Useful for converting to json. To avoid datetime incompatibility use decode_times=False kwarg in xarrray.open_dataset. Parameters From e0fd48052dbda34ee35d2491e4fe856495c9621b Mon Sep 17 00:00:00 2001 From: keewis Date: Tue, 14 Jan 2020 17:13:23 +0100 Subject: [PATCH 14/23] allow passing any iterable to drop when dropping variables (#3693) * allow passing any iterable to drop when dropping variables * whats-new.rst * update whats-new.rst --- doc/whats-new.rst | 3 +++ xarray/core/dataset.py | 3 +-- xarray/tests/test_dataset.py | 4 ++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6eeb55c23cb..1ad344a208d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -76,6 +76,9 @@ Bug fixes MultiIndex level names. - :py:meth:`Dataset.merge` no longer fails when passed a `DataArray` instead of a `Dataset` object. By `Tom Nicholas `_. +- Fix a regression in :py:meth:`Dataset.drop`: allow passing any + iterable when dropping variables (:issue:`3552`, :pull:`3693`) + By `Justus Magin `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1aaad02b470..79f1030fabe 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -85,7 +85,6 @@ either_dict_or_kwargs, hashable, is_dict_like, - is_list_like, is_scalar, maybe_wrap_array, ) @@ -3690,7 +3689,7 @@ def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs): raise ValueError("cannot specify dim and dict-like arguments.") labels = either_dict_or_kwargs(labels, labels_kwargs, "drop") - if dim is None and (is_list_like(labels) or is_scalar(labels)): + if dim is None and (is_scalar(labels) or isinstance(labels, Iterable)): warnings.warn( "dropping variables using `drop` will be deprecated; using drop_vars is encouraged.", PendingDeprecationWarning, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c953f5d22e9..f9eb37dbf2f 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2167,6 +2167,10 @@ def test_drop_variables(self): actual = data.drop(["time", "not_found_here"], errors="ignore") assert_identical(expected, actual) + with pytest.warns(PendingDeprecationWarning): + actual = data.drop({"time", "not_found_here"}, errors="ignore") + assert_identical(expected, actual) + def test_drop_index_labels(self): data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) From 99594051ef591f12b4b78a8b24136da46d0bf28f Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Wed, 15 Jan 2020 10:22:29 -0500 Subject: [PATCH 15/23] Use encoding['dtype'] over data.dtype when possible within CFMaskCoder.encode (#3652) * Use encoding['dtype'] over data.dtype when possible * Add what's new entry * Fix typo in what's new Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 3 +++ xarray/coding/variables.py | 5 +++-- xarray/tests/test_coding.py | 36 +++++++++++++++++++++++++++--------- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1ad344a208d..f8d13098250 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -69,6 +69,9 @@ Bug fixes By `Tom Augspurger `_. - Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser `_. +- Fix regression in xarray 0.14.1 that prevented encoding times with certain + ``dtype``, ``_FillValue``, and ``missing_value`` encodings (:issue:`3624`). + By `Spencer Clark `_ - Raise an error when trying to use :py:meth:`Dataset.rename_dims` to rename to an existing name (:issue:`3438`, :pull:`3645`) By `Justus Magin `_. 
diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 2b5f87ab0cd..28ead397461 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -148,6 +148,7 @@ class CFMaskCoder(VariableCoder): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) + dtype = np.dtype(encoding.get("dtype", data.dtype)) fv = encoding.get("_FillValue") mv = encoding.get("missing_value") @@ -162,14 +163,14 @@ def encode(self, variable, name=None): if fv is not None: # Ensure _FillValue is cast to same dtype as data's - encoding["_FillValue"] = data.dtype.type(fv) + encoding["_FillValue"] = dtype.type(fv) fill_value = pop_to(encoding, attrs, "_FillValue", name=name) if not pd.isnull(fill_value): data = duck_array_ops.fillna(data, fill_value) if mv is not None: # Ensure missing_value is cast to same dtype as data's - encoding["missing_value"] = data.dtype.type(mv) + encoding["missing_value"] = dtype.type(mv) fill_value = pop_to(encoding, attrs, "missing_value", name=name) if not pd.isnull(fill_value) and fv is None: data = duck_array_ops.fillna(data, fill_value) diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 3e0474e7b60..0f191049284 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -1,10 +1,12 @@ from contextlib import suppress import numpy as np +import pandas as pd import pytest import xarray as xr from xarray.coding import variables +from xarray.conventions import decode_cf_variable, encode_cf_variable from . import assert_equal, assert_identical, requires_dask @@ -20,20 +22,36 @@ def test_CFMaskCoder_decode(): assert_identical(expected, encoded) -def test_CFMaskCoder_encode_missing_fill_values_conflict(): - original = xr.Variable( - ("x",), - [0.0, -1.0, 1.0], - encoding={"_FillValue": np.float32(1e20), "missing_value": np.float64(1e20)}, - ) - coder = variables.CFMaskCoder() - encoded = coder.encode(original) +encoding_with_dtype = { + "dtype": np.dtype("float64"), + "_FillValue": np.float32(1e20), + "missing_value": np.float64(1e20), +} +encoding_without_dtype = { + "_FillValue": np.float32(1e20), + "missing_value": np.float64(1e20), +} +CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS = { + "numeric-with-dtype": ([0.0, -1.0, 1.0], encoding_with_dtype), + "numeric-without-dtype": ([0.0, -1.0, 1.0], encoding_without_dtype), + "times-with-dtype": (pd.date_range("2000", periods=3), encoding_with_dtype), +} + + +@pytest.mark.parametrize( + ("data", "encoding"), + CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.values(), + ids=list(CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.keys()), +) +def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding): + original = xr.Variable(("x",), data, encoding=encoding) + encoded = encode_cf_variable(original) assert encoded.dtype == encoded.attrs["missing_value"].dtype assert encoded.dtype == encoded.attrs["_FillValue"].dtype with pytest.warns(variables.SerializationWarning): - roundtripped = coder.decode(coder.encode(original)) + roundtripped = decode_cf_variable("foo", encoded) assert_identical(roundtripped, original) From f8386bed945ec5e586baf906548903da93149258 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 15 Jan 2020 08:25:56 -0700 Subject: [PATCH 16/23] Add an example notebook using apply_ufunc to vectorize 1D functions (#3629) * apply_ufunc vectorize 1D function example * Small udpates. * Review + strip output. * Minor updates. 
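
The pattern the notebook builds up to, in brief (a sketch; ``air`` is
the tutorial air-temperature DataArray and ``newlat`` the target
coordinate, both defined in the notebook itself):

    import numpy as np
    import xarray as xr

    def interp1d_np(data, x, xi):
        # 1D-only function to be vectorized over the other dims
        return np.interp(xi, x, data)

    interped = xr.apply_ufunc(
        interp1d_np,
        air,      # 3D DataArray with a "lat" dimension
        air.lat,
        newlat,   # 1D numpy array of new latitudes
        input_core_dims=[["lat"], ["lat"], ["new_lat"]],
        output_core_dims=[["new_lat"]],
        exclude_dims={"lat"},  # "lat" changes size
        vectorize=True,        # loop over non-core dims with np.vectorize
    )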
---
 ci/requirements/doc.yml                     |   1 +
 doc/examples.rst                            |   7 +
 doc/examples/apply_ufunc_vectorize_1d.ipynb | 736 ++++++++++++++++++++
 doc/whats-new.rst                           |   3 +
 4 files changed, 747 insertions(+)
 create mode 100644 doc/examples/apply_ufunc_vectorize_1d.ipynb

diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml
index a0c27a30b01..1b479ad4ca2 100644
--- a/ci/requirements/doc.yml
+++ b/ci/requirements/doc.yml
@@ -14,6 +14,7 @@ dependencies:
   - jupyter_client
   - nbsphinx
   - netcdf4
+  - numba
   - numpy
   - numpydoc
   - pandas<0.25  # Hack around https://github.com/pydata/xarray/issues/3369
diff --git a/doc/examples.rst b/doc/examples.rst
index ce56102cc9d..3067ca824be 100644
--- a/doc/examples.rst
+++ b/doc/examples.rst
@@ -10,3 +10,10 @@ Examples
    examples/visualization_gallery
    examples/ROMS_ocean_model
    examples/ERA5-GRIB-example
+
+Using apply_ufunc
+------------------
+.. toctree::
+   :maxdepth: 2
+
+   examples/apply_ufunc_vectorize_1d
diff --git a/doc/examples/apply_ufunc_vectorize_1d.ipynb b/doc/examples/apply_ufunc_vectorize_1d.ipynb
new file mode 100644
index 00000000000..6d18d48fdb5
--- /dev/null
+++ b/doc/examples/apply_ufunc_vectorize_1d.ipynb
@@ -0,0 +1,736 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Applying unvectorized functions with `apply_ufunc`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example will illustrate how to conveniently apply an unvectorized function `func` to xarray objects using `apply_ufunc`. `func` expects 1D numpy arrays and returns a 1D numpy array. Our goal is to conveniently apply this function along a dimension of xarray objects that may or may not wrap dask arrays.\n",
    "\n",
    "We will illustrate this using `np.interp`: \n",
    "\n",
    "    Signature: np.interp(x, xp, fp, left=None, right=None, period=None)\n",
    "    Docstring:\n",
    "    One-dimensional linear interpolation.\n",
    "\n",
    "    Returns the one-dimensional piecewise linear interpolant to a function\n",
    "    with given discrete data points (`xp`, `fp`), evaluated at `x`.\n",
    "\n",
    "and write an `xr_interp` function with signature\n",
    "\n",
    "    xr_interp(xarray_object, dimension_name, new_coordinate_to_interpolate_to)\n",
    "\n",
    "### Load data\n",
    "\n",
    "First let's load an example dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-15T14:45:51.659160Z",
     "start_time": "2020-01-15T14:45:50.528742Z"
    }
   },
   "outputs": [],
   "source": [
    "import xarray as xr\n",
    "import numpy as np\n",
    "\n",
    "xr.set_options(display_style=\"html\")  # fancy HTML repr\n",
    "\n",
    "air = (\n",
    "    xr.tutorial.load_dataset(\"air_temperature\")\n",
    "    .air.sortby(\"lat\")  # np.interp needs coordinate in ascending order\n",
    "    .isel(time=slice(4), lon=slice(3))\n",
    ")  # choose a small subset for convenience\n",
    "air"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function we will apply is `np.interp` which expects 1D numpy arrays. This functionality is already implemented in xarray so we use that capability to make sure we are not making mistakes."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:55.431708Z", + "start_time": "2020-01-15T14:45:55.104701Z" + } + }, + "outputs": [], + "source": [ + "newlat = np.linspace(15, 75, 100)\n", + "air.interp(lat=newlat)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define a function that works with one vector of data along `lat` at a time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:57.889496Z", + "start_time": "2020-01-15T14:45:57.792269Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = interp1d_np(air.isel(time=0, lon=0), air.lat, newlat)\n", + "expected = air.interp(lat=newlat)\n", + "\n", + "# no errors are raised if values are equal to within floating point precision\n", + "np.testing.assert_allclose(expected.isel(time=0, lon=0).values, interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### No errors are raised so our interpolation is working.\n", + "\n", + "This function consumes and returns numpy arrays, which means we need to do a lot of work to convert the result back to an xarray object with meaningful metadata. This is where `apply_ufunc` is very useful.\n", + "\n", + "### `apply_ufunc`\n", + "\n", + " Apply a vectorized function for unlabeled arrays on xarray objects.\n", + "\n", + " The function will be mapped over the data variable(s) of the input arguments using \n", + " xarray’s standard rules for labeled computation, including alignment, broadcasting, \n", + " looping over GroupBy/Dataset variables, and merging of coordinates.\n", + " \n", + "`apply_ufunc` has many capabilities but for simplicity this example will focus on the common task of vectorizing 1D functions over nD xarray objects. We will iteratively build up the right set of arguments to `apply_ufunc` and read through many error messages in doing so." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:59.768626Z", + "start_time": "2020-01-15T14:45:59.543808Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`apply_ufunc` needs to know a lot of information about what our function does so that it can reconstruct the outputs. In this case, the size of dimension lat has changed and we need to explicitly specify that this will happen. xarray helpfully tells us that we need to specify the kwarg `exclude_dims`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `exclude_dims`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "```\n", + "exclude_dims : set, optional\n", + " Core dimensions on the inputs to exclude from alignment and\n", + " broadcasting entirely. Any input coordinates along these dimensions\n", + " will be dropped. Each excluded dimension must also appear in\n", + " ``input_core_dims`` for at least one argument. 
Only dimensions listed\n", + " here are allowed to change size between input and output objects.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:02.187012Z", + "start_time": "2020-01-15T14:46:02.105563Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Core dimensions\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Core dimensions are central to using `apply_ufunc`. In our case, our function expects to receive a 1D vector along `lat` — this is the dimension that is \"core\" to the function's functionality. Multiple core dimensions are possible. `apply_ufunc` needs to know which dimensions of each variable are core dimensions.\n", + "\n", + " input_core_dims : Sequence[Sequence], optional\n", + " List of the same length as ``args`` giving the list of core dimensions\n", + " on each input argument that should not be broadcast. By default, we\n", + " assume there are no core dimensions on any input arguments.\n", + "\n", + " For example, ``input_core_dims=[[], ['time']]`` indicates that all\n", + " dimensions on the first argument and all dimensions other than 'time'\n", + " on the second argument should be broadcast.\n", + "\n", + " Core dimensions are automatically moved to the last axes of input\n", + " variables before applying ``func``, which facilitates using NumPy style\n", + " generalized ufuncs [2]_.\n", + " \n", + " output_core_dims : List[tuple], optional\n", + " List of the same length as the number of output arguments from\n", + " ``func``, giving the list of core dimensions on each output that were\n", + " not broadcast on the inputs. By default, we assume that ``func``\n", + " outputs exactly one array, with axes corresponding to each broadcast\n", + " dimension.\n", + "\n", + " Core dimensions are assumed to appear as the last dimensions of each\n", + " output in the provided order.\n", + " \n", + "Next we specify `\"lat\"` as `input_core_dims` on both `air` and `air.lat`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:05.031672Z", + "start_time": "2020-01-15T14:46:04.947588Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "xarray is telling us that it expected to receive back a numpy array with 0 dimensions but instead received an array with 1 dimension corresponding to `newlat`. 
We can fix this by specifying `output_core_dims`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:09.325218Z", + "start_time": "2020-01-15T14:46:09.303020Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we get some output! Let's check that this is right\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:11.295440Z", + "start_time": "2020-01-15T14:46:11.226553Z" + } + }, + "outputs": [], + "source": [ + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.isel(time=0, lon=0), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "No errors are raised so it is right!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vectorization with `np.vectorize`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now our function currently only works on one vector of data which is not so useful given our 3D dataset.\n", + "Let's try passing the whole dataset. We add a `print` statement so we can see what our function receives." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:13.808646Z", + "start_time": "2020-01-15T14:46:13.680098Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(\n", + " lon=slice(3), time=slice(4)\n", + " ), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.isel(time=0, lon=0), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's a hard-to-interpret error but our `print` call helpfully printed the shapes of the input data: \n", + "\n", + " data: (10, 53, 25) | x: (25,) | xi: (100,)\n", + "\n", + "We see that `air` has been passed as a 3D numpy array which is not what `np.interp` expects. 
Instead we want loop over all combinations of `lon` and `time`; and apply our function to each corresponding vector of data along `lat`.\n", + "`apply_ufunc` makes this easy by specifying `vectorize=True`:\n", + "\n", + " vectorize : bool, optional\n", + " If True, then assume ``func`` only takes arrays defined over core\n", + " dimensions as input and vectorize it automatically with\n", + " :py:func:`numpy.vectorize`. This option exists for convenience, but is\n", + " almost always slower than supplying a pre-vectorized function.\n", + " Using this option requires NumPy version 1.12 or newer.\n", + " \n", + "Also see the documentation for `np.vectorize`: https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html. Most importantly\n", + "\n", + " The vectorize function is provided primarily for convenience, not for performance. \n", + " The implementation is essentially a for loop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:26.633233Z", + "start_time": "2020-01-15T14:46:26.515209Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air, # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + " vectorize=True, # loop over non-core dims\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected, interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This unfortunately is another cryptic error from numpy. \n", + "\n", + "Notice that `newlat` is not an xarray object. Let's add a dimension name `new_lat` and modify the call. Note this cannot be `lat` because xarray expects dimensions to be the same size (or broadcastable) among all inputs. `output_core_dims` needs to be modified appropriately. We'll manually rename `new_lat` back to `lat` for easy checking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:30.026663Z", + "start_time": "2020-01-15T14:46:29.893267Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air, # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. 
Must be a set!\n", + " vectorize=True, # loop over non-core dims\n", + ")\n", + "interped = interped.rename({\"new_lat\": \"lat\"})\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(\n", + " expected.transpose(*interped.dims), interped # order of dims is different\n", + ")\n", + "interped" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the printed input shapes are all 1D and correspond to one vector along the `lat` dimension.\n", + "\n", + "The result is now an xarray object with coordinate values copied over from `data`. This is why `apply_ufunc` is so convenient; it takes care of a lot of boilerplate necessary to apply functions that consume and produce numpy arrays to xarray objects.\n", + "\n", + "One final point: `lat` is now the *last* dimension in `interped`. This is a \"property\" of core dimensions: they are moved to the end before being sent to `interp1d_np` as was noted in the docstring for `input_core_dims`\n", + "\n", + " Core dimensions are automatically moved to the last axes of input\n", + " variables before applying ``func``, which facilitates using NumPy style\n", + " generalized ufuncs [2]_." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Parallelization with dask\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So far our function can only handle numpy arrays. A real benefit of `apply_ufunc` is the ability to easily parallelize over dask chunks _when needed_. \n", + "\n", + "We want to apply this function in a vectorized fashion over each chunk of the dask array. This is possible using dask's `blockwise` or `map_blocks`. `apply_ufunc` wraps `blockwise` and asking it to map the function over chunks using `blockwise` is as simple as specifying `dask=\"parallelized\"`. With this level of flexibility we need to provide dask with some extra information: \n", + " 1. `output_dtypes`: dtypes of all returned objects, and \n", + " 2. `output_sizes`: lengths of any new dimensions. \n", + " \n", + "Here we need to specify `output_dtypes` since `apply_ufunc` can infer the size of the new dimension `new_lat` from the argument corresponding to the third element in `input_core_dims`. Here I choose the chunk sizes to illustrate that `np.vectorize` is still applied so that our function receives 1D vectors even though the blocks are 3D." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:48:42.469341Z", + "start_time": "2020-01-15T14:48:42.344209Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.chunk(\n", + " {\"time\": 2, \"lon\": 2}\n", + " ), # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. 
Must be a set!\n",
+    "    vectorize=True,  # loop over non-core dims\n",
+    "    dask=\"parallelized\",\n",
+    "    output_dtypes=[air.dtype],  # one per output\n",
+    ").rename({\"new_lat\": \"lat\"})\n",
+    "interped[\"lat\"] = newlat  # need to add this manually\n",
+    "xr.testing.assert_allclose(expected.transpose(*interped.dims), interped)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Yay! Our function is receiving 1D vectors, so we've successfully parallelized applying a 1D function over a block. If you have a distributed dashboard up, you should see computes happening as equality is checked.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### High performance vectorization: gufuncs, numba & guvectorize\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`np.vectorize` is a very convenient function but is unfortunately slow. It is only marginally faster than writing an explicit Python loop. A common way to get around this is to write a base interpolation function that can handle nD arrays in a compiled language like Fortran and then pass that to `apply_ufunc`.\n",
+    "\n",
+    "Another option is to use the numba package, which provides a very convenient `guvectorize` decorator: https://numba.pydata.org/numba-doc/latest/user/vectorize.html#the-guvectorize-decorator\n",
+    "\n",
+    "Any decorated function gets compiled and will loop over any non-core dimension in parallel when necessary. We need to specify some extra information:\n",
+    "\n",
+    " 1. Our function cannot return a variable any more. Instead it must receive a variable (the last argument) whose contents the function will modify. So we change from `def interp1d_np(data, x, xi)` to `def interp1d_np_gufunc(data, x, xi, out)`. Our computed results must be assigned to `out`. All values of `out` must be assigned explicitly.\n",
+    " \n",
+    " 2. `guvectorize` needs to know the dtypes of the input and output. This is specified in string form as the first argument. Each element of the tuple corresponds to each argument of the function. In this case, we specify `float64` for all inputs and outputs: `\"(float64[:], float64[:], float64[:], float64[:])\"` corresponding to `data, x, xi, out`.\n",
+    " \n",
+    " 3. Now we need to tell numba the size of the dimensions the function takes as inputs and returns as output, i.e. the core dimensions. This is done in symbolic form, i.e. `data` and `x` are vectors of the same length, say `n`; `xi` and the output `out` have a different length, say `m`. 
So the second argument is (again as a string)\n",
+    " `\"(n), (n), (m) -> (m)\"`, corresponding again to `data, x, xi, out`.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-01-15T14:48:45.267633Z",
+     "start_time": "2020-01-15T14:48:44.943939Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from numba import float64, guvectorize\n",
+    "\n",
+    "\n",
+    "@guvectorize(\"(float64[:], float64[:], float64[:], float64[:])\", \"(n), (n), (m) -> (m)\")\n",
+    "def interp1d_np_gufunc(data, x, xi, out):\n",
+    "    # numba doesn't really like this print; it doesn't\n",
+    "    # seem to support fstrings so do it the old way\n",
+    "    print(\n",
+    "        \"data: \" + str(data.shape) + \" | x: \" + str(x.shape) + \" | xi: \" + str(xi.shape)\n",
+    "    )\n",
+    "    out[:] = np.interp(xi, x, data)\n",
+    "    # gufuncs don't return data\n",
+    "    # instead you assign to the last arg\n",
+    "    # return np.interp(xi, x, data)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The warnings are about object-mode compilation relating to the `print` statement. This means we don't get much speed up: https://numba.pydata.org/numba-doc/latest/user/performance-tips.html#no-python-mode-vs-object-mode. We'll keep the `print` statement temporarily to make sure that `guvectorize` acts like we want it to."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-01-15T14:48:54.755405Z",
+     "start_time": "2020-01-15T14:48:54.634724Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "interped = xr.apply_ufunc(\n",
+    "    interp1d_np_gufunc,  # first the function\n",
+    "    air.chunk(\n",
+    "        {\"time\": 2, \"lon\": 2}\n",
+    "    ),  # now arguments in the order expected by 'interp1_np'\n",
+    "    air.lat,  # as above\n",
+    "    newlat,  # as above\n",
+    "    input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]],  # list with one entry per arg\n",
+    "    output_core_dims=[[\"new_lat\"]],  # returned data has one dimension\n",
+    "    exclude_dims=set((\"lat\",)),  # dimensions allowed to change size. Must be a set!\n",
+    "    # vectorize=True,  # not needed since numba takes care of vectorizing\n",
+    "    dask=\"parallelized\",\n",
+    "    output_dtypes=[air.dtype],  # one per output\n",
+    ").rename({\"new_lat\": \"lat\"})\n",
+    "interped[\"lat\"] = newlat  # need to add this manually\n",
+    "xr.testing.assert_allclose(expected.transpose(*interped.dims), interped)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Yay! Our function is receiving 1D vectors and is working automatically with dask arrays. 
Finally let's comment out the print line and wrap everything up in a nice reusable function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:49:28.667528Z", + "start_time": "2020-01-15T14:49:28.103914Z" + } + }, + "outputs": [], + "source": [ + "from numba import float64, guvectorize\n", + "\n", + "\n", + "@guvectorize(\n", + " \"(float64[:], float64[:], float64[:], float64[:])\",\n", + " \"(n), (n), (m) -> (m)\",\n", + " nopython=True,\n", + ")\n", + "def interp1d_np_gufunc(data, x, xi, out):\n", + " out[:] = np.interp(xi, x, data)\n", + "\n", + "\n", + "def xr_interp(data, dim, newdim):\n", + "\n", + " interped = xr.apply_ufunc(\n", + " interp1d_np_gufunc, # first the function\n", + " data, # now arguments in the order expected by 'interp1_np'\n", + " data[dim], # as above\n", + " newdim, # as above\n", + " input_core_dims=[[dim], [dim], [\"__newdim__\"]], # list with one entry per arg\n", + " output_core_dims=[[\"__newdim__\"]], # returned data has one dimension\n", + " exclude_dims=set((dim,)), # dimensions allowed to change size. Must be a set!\n", + " # vectorize=True, # not needed since numba takes care of vectorizing\n", + " dask=\"parallelized\",\n", + " output_dtypes=[data.dtype], # one per output; could also be float or np.dtype(\"float64\")\n", + " ).rename({\"__newdim__\": dim})\n", + " interped[dim] = newdim # need to add this manually\n", + "\n", + " return interped\n", + "\n", + "\n", + "xr.testing.assert_allclose(\n", + " expected.transpose(*interped.dims),\n", + " xr_interp(air.chunk({\"time\": 2, \"lon\": 2}), \"lat\", newlat),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This technique is generalizable to any 1D function." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "nbsphinx": { + "allow_errors": true + }, + "org": null, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": false, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f8d13098250..17c679d27b5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -100,6 +100,9 @@ Documentation - Added examples for :py:meth:`DataArray.quantile`, :py:meth:`Dataset.quantile` and ``GroupBy.quantile``. (:pull:`3576`) By `Justus Magin `_. +- Add new :py:func:`apply_ufunc` example notebook demonstrating vectorization of a 1D + function using dask and numba. + By `Deepak Cherian `_. - Added example for :py:func:`~xarray.map_blocks`. (:pull:`3667`) By `Riley X. Brady `_. 
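The pattern the notebook builds up to is compact enough to restate in one place. The sketch below mirrors the notebook's final `apply_ufunc` call on hypothetical stand-in data (it does not assume the tutorial's `air` dataset, numba, or dask); it illustrates the technique and is not part of the patch itself:

    import numpy as np
    import xarray as xr

    def interp1d_np(data, x, xi):
        # pure-numpy kernel: interpolate one vector of data from x onto xi
        return np.interp(xi, x, data)

    # hypothetical stand-in for the tutorial's air-temperature data
    air = xr.DataArray(
        np.random.rand(4, 3, 25),
        dims=("time", "lon", "lat"),
        coords={"lat": np.linspace(0, 1, 25)},
    )
    newlat = np.linspace(0, 1, 100)

    interped = xr.apply_ufunc(
        interp1d_np,
        air,
        air.lat,
        newlat,
        input_core_dims=[["lat"], ["lat"], ["new_lat"]],
        output_core_dims=[["new_lat"]],
        exclude_dims={"lat"},  # "lat" changes size, so it must be excluded
        vectorize=True,  # np.vectorize loops over the non-core dims (time, lon)
    )
    interped = interped.rename({"new_lat": "lat"}).assign_coords(lat=newlat)

The size of `new_lat` is inferred from the third argument, and the core dimension ends up last in the result, exactly as the notebook explains.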
From 3955c37959a81d6e83c6253962d374ee2f5019b6 Mon Sep 17 00:00:00 2001 From: keewis Date: Wed, 15 Jan 2020 17:53:01 +0100 Subject: [PATCH 17/23] Tests for variables with units (#3654) * add tests for variable by inheriting from VariableSubclassobjects * make sure the utility functions work with variables * add additional tests for aggregation and numpy methods * don't assume variables have a name attribute * properly merge the used arguments * properly attach to non-xarray objects * add a test function for searchsorted and item * add tests for missing value detection methods * add tests for comparisons * don't try to check missing value behaviour for int arrays * xfail the compatible unit tests * use an additional check since identical is not sufficient right now * check for compatibility by comparing the dimensionality of units * add tests for fillna * add tests for isel * add tests for 1d math * add tests for broadcast_equals * add tests for masking * add tests for squeeze * add tests for coarsen, quantile, roll, round and shift * add tests for searchsorted * add tests for rolling_window * add tests for transpose * add tests for stack and unstack * add tests for set_dims * add tests for concat, copy and no_conflicts * add tests for reduce * add tests for pad_with_fill_value * all and any have been implemented by pint * remove a unnecessary call to np.array * mark the unrelated failures as xfail * remove a debug print * document the need for force_ndarray * make is_compatible a little bit more robust * clean up is_compatible * add docstrings explaining the use of the method and function classes * put the unit comparisons into a function * actually squeeze the dimensions separately --- xarray/tests/test_units.py | 854 ++++++++++++++++++++++++++++++++++++- 1 file changed, 844 insertions(+), 10 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index f8a8a259c1f..2cb1550c088 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -7,12 +7,15 @@ import xarray as xr from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE +from .test_variable import VariableSubclassobjects pint = pytest.importorskip("pint") DimensionalityError = pint.errors.DimensionalityError -unit_registry = pint.UnitRegistry() +# make sure scalars are converted to 0d arrays so quantities can +# always be treated like ndarrays +unit_registry = pint.UnitRegistry(force_ndarray=True) Quantity = unit_registry.Quantity pytestmark = [ @@ -28,6 +31,25 @@ ] +def is_compatible(unit1, unit2): + def dimensionality(obj): + if isinstance(obj, (unit_registry.Quantity, unit_registry.Unit)): + unit_like = obj + else: + unit_like = unit_registry.dimensionless + + return unit_like.dimensionality + + return dimensionality(unit1) == dimensionality(unit2) + + +def compatible_mappings(first, second): + return { + key: is_compatible(unit1, unit2) + for key, (unit1, unit2) in merge_mappings(first, second) + } + + def array_extract_units(obj): if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): obj = obj.data @@ -113,6 +135,10 @@ def extract_units(obj): } units = {**vars_units, **coords_units} + elif isinstance(obj, xr.Variable): + vars_units = {None: array_extract_units(obj.data)} + + units = {**vars_units} elif isinstance(obj, Quantity): vars_units = {None: array_extract_units(obj)} @@ -148,6 +174,9 @@ def strip_units(obj): new_obj = xr.DataArray( name=strip_units(obj.name), data=data, coords=coords, dims=obj.dims ) + elif isinstance(obj, xr.Variable): + data = 
array_strip_units(obj.data) + new_obj = obj.copy(data=data) elif isinstance(obj, unit_registry.Quantity): new_obj = obj.magnitude elif isinstance(obj, (list, tuple)): @@ -159,8 +188,9 @@ def strip_units(obj): def attach_units(obj, units): - if not isinstance(obj, (xr.DataArray, xr.Dataset)): - return array_attach_units(obj, units.get("data", 1)) + if not isinstance(obj, (xr.DataArray, xr.Dataset, xr.Variable)): + units = units.get("data", None) or units.get(None, None) or 1 + return array_attach_units(obj, units) if isinstance(obj, xr.Dataset): data_vars = { @@ -172,7 +202,7 @@ def attach_units(obj, units): } new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) - else: + elif isinstance(obj, xr.DataArray): # try the array name, "data" and None, then fall back to dimensionless data_units = ( units.get(obj.name, None) @@ -198,6 +228,11 @@ def attach_units(obj, units): new_obj = xr.DataArray( name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims ) + else: + data_units = units.get("data", None) or units.get(None, None) or 1 + + data = array_attach_units(obj.data, data_units) + new_obj = obj.copy(data=data) return new_obj @@ -243,6 +278,10 @@ def convert_units(obj, to): return new_obj +def assert_units_equal(a, b): + assert extract_units(a) == extract_units(b) + + def assert_equal_with_units(a, b): # works like xr.testing.assert_equal, but also explicitly checks units # so, it is more like assert_identical @@ -282,7 +321,27 @@ def dtype(request): return request.param +def merge_mappings(*mappings): + for key in set(mappings[0]).intersection(*mappings[1:]): + yield key, tuple(m[key] for m in mappings) + + +def merge_args(default_args, new_args): + from itertools import zip_longest + + fill_value = object() + return [ + second if second is not fill_value else first + for first, second in zip_longest(default_args, new_args, fillvalue=fill_value) + ] + + class method: + """ wrapper class to help with passing methods via parametrize + + This is works a bit similar to using `partial(Class.method, arg, kwarg)` + """ + def __init__(self, name, *args, **kwargs): self.name = name self.args = args @@ -292,7 +351,7 @@ def __call__(self, obj, *args, **kwargs): from collections.abc import Callable from functools import partial - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} func = getattr(obj, self.name, None) @@ -301,7 +360,7 @@ def __call__(self, obj, *args, **kwargs): if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): numpy_func = getattr(np, self.name) func = partial(numpy_func, obj) - # remove typical xr args like "dim" + # remove typical xarray args like "dim" exclude_kwargs = ("dim", "dims") all_kwargs = { key: value @@ -318,12 +377,21 @@ def __repr__(self): class function: - def __init__(self, name_or_function, *args, **kwargs): + """ wrapper class for numpy functions + + Same as method, but the name is used for referencing numpy functions + """ + + def __init__(self, name_or_function, *args, function_label=None, **kwargs): if callable(name_or_function): - self.name = name_or_function.__name__ + self.name = ( + function_label + if function_label is not None + else name_or_function.__name__ + ) self.func = name_or_function else: - self.name = name_or_function + self.name = name_or_function if function_label is None else function_label self.func = getattr(np, name_or_function) if self.func is None: raise AttributeError( @@ -334,7 +402,7 @@ def __init__(self, name_or_function, *args, 
**kwargs): self.kwargs = kwargs def __call__(self, *args, **kwargs): - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} return self.func(*all_args, **all_kwargs) @@ -343,6 +411,7 @@ def __repr__(self): return f"function_{self.name}" +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") def test_apply_ufunc_dataarray(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -358,6 +427,7 @@ def test_apply_ufunc_dataarray(dtype): assert_equal_with_units(expected, actual) +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") def test_apply_ufunc_dataset(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -1245,6 +1315,769 @@ def test_dot_dataarray(dtype): assert_equal_with_units(expected, actual) +def delete_attrs(*to_delete): + def wrapper(cls): + for item in to_delete: + setattr(cls, item, None) + + return cls + + return wrapper + + +@delete_attrs( + "test_getitem_with_mask", + "test_getitem_with_mask_nd_indexer", + "test_index_0d_string", + "test_index_0d_datetime", + "test_index_0d_timedelta64", + "test_0d_time_data", + "test_datetime64_conversion", + "test_timedelta64_conversion", + "test_pandas_period_index", + "test_1d_math", + "test_1d_reduce", + "test_array_interface", + "test___array__", + "test_copy_index", + "test_concat_number_strings", + "test_concat_fixed_len_str", + "test_concat_mixed_dtypes", + "test_pandas_datetime64_with_tz", + "test_pandas_data", + "test_multiindex", +) +class TestVariable(VariableSubclassobjects): + @staticmethod + def cls(dims, data, *args, **kwargs): + return xr.Variable( + dims, unit_registry.Quantity(data, unit_registry.m), *args, **kwargs + ) + + @pytest.mark.parametrize( + "func", + ( + method("all"), + method("any"), + method("argmax"), + method("argmin"), + method("argsort"), + method("cumprod"), + method("cumsum"), + method("max"), + method("mean"), + method("median"), + method("min"), + pytest.param( + method("prod"), + marks=pytest.mark.xfail(reason="not implemented by pint"), + ), + method("std"), + method("sum"), + method("var"), + ), + ids=repr, + ) + def test_aggregation(self, func, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * ( + unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless + ) + variable = xr.Variable("x", array) + + units = extract_units(func(array)) + expected = attach_units(func(strip_units(variable)), units) + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("astype", np.float32), + method("conj"), + method("conjugate"), + method("clip", min=2, max=7), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit if isinstance(item, (int, float, list)) else item + for item in func.args + ] + kwargs = { + key: value * unit if 
isinstance(value, (int, float, list)) else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name in ("searchsorted", "clip"): + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", (method("item", 5), method("searchsorted", 5)), ids=repr + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_raw_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit + if isinstance(item, (int, float, list)) and func.name != "item" + else item + for item in func.args + ] + kwargs = { + key: value * unit + if isinstance(value, (int, float, list)) and func.name != "item" + else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name != "item": + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) + if func.name != "item" + else item + for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + if func.name != "item" + else value + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + np.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func): + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.degK + ) + variable = xr.Variable(("x", "y"), array) + + expected = func(strip_units(variable)) + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail(reason="uses 0 as a replacement"), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail(reason="converts to fill value's unit"), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def 
test_missing_value_fillna(self, unit, error): + value = 0 + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.m + ) + variable = xr.Variable(("x", "y"), array) + + fill_value = value * unit + + if error is not None: + with pytest.raises(error): + variable.fillna(value=fill_value) + + return + + expected = attach_units( + strip_units(variable).fillna( + value=fill_value.to(unit_registry.m).magnitude + ), + extract_units(variable), + ) + actual = variable.fillna(value=fill_value) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="checking for identical units does not work properly, yet" + ), + ), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "convert_data", + ( + pytest.param(False, id="no_conversion"), + pytest.param(True, id="with_conversion"), + ), + ) + @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) + def test_comparisons(self, func, unit, convert_data, dtype): + array = np.linspace(0, 1, 9).astype(dtype) + quantity1 = array * unit_registry.m + variable = xr.Variable("x", quantity1) + + if convert_data and is_compatible(unit_registry.m, unit): + quantity2 = convert_units(array * unit_registry.m, {None: unit}) + else: + quantity2 = array * unit + other = xr.Variable("x", quantity2) + + expected = func( + strip_units(variable), + strip_units( + convert_units(other, extract_units(variable)) + if is_compatible(unit_registry.m, unit) + else other + ), + ) + if func.name == "identical": + expected &= extract_units(variable) == extract_units(other) + else: + expected &= all( + compatible_mappings( + extract_units(variable), extract_units(other) + ).values() + ) + + actual = func(variable, other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + base_unit = unit_registry.m + left_array = np.ones(shape=(2, 2), dtype=dtype) * base_unit + value = ( + (1 * base_unit).to(unit).magnitude if is_compatible(unit, base_unit) else 1 + ) + right_array = np.full(shape=(2,), fill_value=value, dtype=dtype) * unit + + left = xr.Variable(("x", "y"), left_array) + right = xr.Variable("x", right_array) + + units = { + **extract_units(left), + **({} if is_compatible(unit, base_unit) else {None: None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & is_compatible(unit, base_unit) + actual = left.broadcast_equals(right) + + assert expected == actual + + @pytest.mark.parametrize( + "indices", + ( + pytest.param(4, id="single index"), + pytest.param([5, 2, 9, 1], id="multiple indices"), + ), + ) + def test_isel(self, indices, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.s + variable = xr.Variable("x", array) + + expected = attach_units( + 
strip_units(variable).isel(x=indices), extract_units(variable) + ) + actual = variable.isel(x=indices) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + function(lambda x, *_: +x, function_label="unary_plus"), + function(lambda x, *_: -x, function_label="unary_minus"), + function(lambda x, *_: abs(x), function_label="absolute"), + function(lambda x, y: x + y, function_label="sum"), + function(lambda x, y: y + x, function_label="commutative_sum"), + function(lambda x, y: x * y, function_label="product"), + function(lambda x, y: y * x, function_label="commutative_product"), + ), + ids=repr, + ) + def test_1d_math(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.arange(5).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + values = np.ones(5) + y = values * unit + + if error is not None and func.name in ("sum", "commutative_sum"): + with pytest.raises(error): + func(variable, y) + + return + + units = extract_units(func(array, y)) + if all(compatible_mappings(units, extract_units(y)).values()): + converted_y = convert_units(y, units) + else: + converted_y = y + + if all(compatible_mappings(units, extract_units(variable)).values()): + converted_variable = convert_units(variable, units) + else: + converted_variable = variable + + expected = attach_units( + func(strip_units(converted_variable), strip_units(converted_y)), units + ) + actual = func(variable, y) + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="getitem_with_mask converts to the unit of other" + ), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", (method("where"), method("_getitem_with_mask")), ids=repr + ) + def test_masking(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + cond = np.array([True, False] * 5) + + other = -1 * unit + + if error is not None: + with pytest.raises(error): + func(variable, cond, other) + + return + + expected = attach_units( + func( + strip_units(variable), + cond, + strip_units( + convert_units( + other, + {None: base_unit} + if is_compatible(base_unit, unit) + else {None: None}, + ) + ), + ), + extract_units(variable), + ) + actual = func(variable, cond, other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_squeeze(self, dtype): + shape = (2, 1, 3, 1, 1, 2) + names = list("abcdef") + array = np.ones(shape=shape) * unit_registry.m + variable = xr.Variable(names, array) + + expected = attach_units( + strip_units(variable).squeeze(), 
extract_units(variable) + ) + actual = variable.squeeze() + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + names = tuple(name for name, size in zip(names, shape) if shape == 1) + for name in names: + expected = attach_units( + strip_units(variable).squeeze(dim=name), extract_units(variable) + ) + actual = variable.squeeze(dim=name) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("coarsen", windows={"y": 2}, func=np.mean), + method("quantile", q=[0.25, 0.75]), + pytest.param( + method("rank", dim="x"), + marks=pytest.mark.xfail(reason="rank not implemented for non-ndarray"), + ), + method("roll", {"x": 2}), + pytest.param( + method("rolling_window", "x", 3, "window"), + marks=pytest.mark.xfail(reason="converts to ndarray"), + ), + method("reduce", np.std, "x"), + method("round", 2), + pytest.param( + method("shift", {"x": -2}), + marks=pytest.mark.xfail( + reason="trying to concatenate ndarray to quantity" + ), + ), + method("transpose", "y", "x"), + ), + ids=repr, + ) + def test_computation(self, func, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 5 * 10).reshape(5, 10).astype(dtype) * base_unit + variable = xr.Variable(("x", "y"), array) + + expected = attach_units(func(strip_units(variable)), extract_units(variable)) + + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_searchsorted(self, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + value = 0 * unit + + if error is not None: + with pytest.raises(error): + variable.searchsorted(value) + + return + + expected = strip_units(variable).searchsorted( + strip_units(convert_units(value, {None: base_unit})) + ) + + actual = variable.searchsorted(value) + + assert_units_equal(expected, actual) + np.testing.assert_allclose(expected, actual) + + def test_stack(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + expected = attach_units( + strip_units(variable).stack(z=("x", "y")), extract_units(variable) + ) + actual = variable.stack(z=("x", "y")) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_unstack(self, dtype): + array = np.linspace(0, 5, 3 * 10).astype(dtype) * unit_registry.m + variable = xr.Variable("z", array) + + expected = attach_units( + strip_units(variable).unstack(z={"x": 3, "y": 10}), extract_units(variable) + ) + actual = variable.unstack(z={"x": 3, "y": 10}) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.xfail(reason="ignores units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + 
pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_concat(self, unit, error, dtype): + array1 = ( + np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + ) + array2 = np.linspace(5, 10, 10 * 2).reshape(10, 2).astype(dtype) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable(("y", "z"), array2) + + if error is not None: + with pytest.raises(error): + variable.concat(other) + + return + + units = extract_units(variable) + expected = attach_units( + strip_units(variable).concat(strip_units(convert_units(other, units))), + units, + ) + actual = variable.concat(other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_set_dims(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + dims = {"z": 6, "x": 3, "a": 1, "b": 4, "y": 10} + expected = attach_units( + strip_units(variable).set_dims(dims), extract_units(variable) + ) + actual = variable.set_dims(dims) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_copy(self, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + other = np.arange(10).astype(dtype) * unit_registry.s + + variable = xr.Variable("x", array) + expected = attach_units( + strip_units(variable).copy(data=strip_units(other)), extract_units(other) + ) + actual = variable.copy(data=other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_no_conflicts(self, unit, dtype): + base_unit = unit_registry.m + array1 = ( + np.array( + [ + [6.3, 0.3, 0.45], + [np.nan, 0.3, 0.3], + [3.7, np.nan, 0.2], + [9.43, 0.3, 0.7], + ] + ) + * base_unit + ) + array2 = np.array([np.nan, 0.3, np.nan]) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable("y", array2) + + expected = strip_units(variable).no_conflicts( + strip_units( + convert_units( + other, {None: base_unit if is_compatible(base_unit, unit) else None} + ) + ) + ) & is_compatible(base_unit, unit) + actual = variable.no_conflicts(other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail( + reason="is not treated the same as dimensionless" + ), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_pad_with_fill_value(self, unit, error, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + fill_value = -100 * unit + + func = method("pad_with_fill_value", x=(2, 3), y=(1, 4)) + if error is not None: + with pytest.raises(error): + func(variable, fill_value=fill_value) + + return + + units = extract_units(variable) + expected = attach_units( + func( + 
strip_units(variable), + fill_value=strip_units(convert_units(fill_value, units)), + ), + units, + ) + actual = func(variable, fill_value=fill_value) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") @pytest.mark.parametrize( @@ -3539,6 +4372,7 @@ def test_stacking_stacked(self, func, dtype): assert_equal_with_units(expected, actual) + @pytest.mark.xfail(reason="does not work with quantities yet") def test_to_stacked_array(self, dtype): labels = np.arange(5).astype(dtype) * unit_registry.s arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels} From 46dfb779df2663c4b012124ae2ec193cca086c35 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 17 Jan 2020 11:21:30 -0500 Subject: [PATCH 18/23] remove DataArray and Dataset constructor deprecations for 0.15 (#3560) * remove 0.15 deprecations * whatsnew --- doc/whats-new.rst | 5 +++++ xarray/core/dataarray.py | 16 +--------------- xarray/core/dataset.py | 14 +------------- 3 files changed, 7 insertions(+), 28 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 17c679d27b5..a60419f89a3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,11 @@ v0.15.0 (unreleased) Breaking changes ~~~~~~~~~~~~~~~~ +- Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which + have been deprecated since 0.12. (:pull:`3650`). + Instead, specify the encoding when writing to disk or set + the ``encoding`` attribute directly. + By `Maximilian Roos `_ New Features ~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 15db6ed468e..0e67a791834 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -267,8 +267,6 @@ def __init__( dims: Union[Hashable, Sequence[Hashable], None] = None, name: Hashable = None, attrs: Mapping = None, - # deprecated parameters - encoding=None, # internal parameters indexes: Dict[Hashable, pd.Index] = None, fastpath: bool = False, @@ -313,20 +311,10 @@ def __init__( Attributes to assign to the new instance. By default, an empty attribute dictionary is initialized. """ - if encoding is not None: - warnings.warn( - "The `encoding` argument to `DataArray` is deprecated, and . " - "will be removed in 0.15. 
" - "Instead, specify the encoding when writing to disk or " - "set the `encoding` attribute directly.", - FutureWarning, - stacklevel=2, - ) if fastpath: variable = data assert dims is None assert attrs is None - assert encoding is None else: # try to fill in arguments from data if they weren't supplied if coords is None: @@ -348,13 +336,11 @@ def __init__( name = getattr(data, "name", None) if attrs is None and not isinstance(data, PANDAS_TYPES): attrs = getattr(data, "attrs", None) - if encoding is None: - encoding = getattr(data, "encoding", None) data = _check_data_shape(data, coords, dims) data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) - variable = Variable(dims, data, attrs, encoding, fastpath=True) + variable = Variable(dims, data, attrs, fastpath=True) indexes = dict( _extract_indexes_from_coords(coords) ) # needed for to_dataset diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 79f1030fabe..4564dc80a8a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -463,7 +463,6 @@ def __init__( data_vars: Mapping[Hashable, Any] = None, coords: Mapping[Hashable, Any] = None, attrs: Mapping[Hashable, Any] = None, - compat=None, ): """To load data from a file or file-like object, use the `open_dataset` function. @@ -513,18 +512,7 @@ def __init__( attrs : dict-like, optional Global attributes to save on this dataset. - compat : deprecated """ - if compat is not None: - warnings.warn( - "The `compat` argument to Dataset is deprecated and will be " - "removed in 0.15." - "Instead, use `merge` to control how variables are combined", - FutureWarning, - stacklevel=2, - ) - else: - compat = "broadcast_equals" # TODO(shoyer): expose indexes as a public argument in __init__ @@ -544,7 +532,7 @@ def __init__( coords = coords.variables variables, coord_names, dims, indexes = merge_data_and_coords( - data_vars, coords, compat=compat + data_vars, coords, compat="broadcast_equals" ) self._attrs = dict(attrs) if attrs is not None else None From 924c05212223defe0720c9da7cdb7583c2c5b696 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 17 Jan 2020 18:51:50 +0000 Subject: [PATCH 19/23] Bump mypy to v0.761 (#3704) --- .pre-commit-config.yaml | 2 +- ci/requirements/py36-min-all-deps.yml | 2 +- ci/requirements/py36.yml | 2 +- ci/requirements/py37-windows.yml | 2 +- ci/requirements/py37.yml | 2 +- xarray/core/computation.py | 2 +- xarray/core/dataset.py | 6 +++--- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 502120cd5dc..ed62c1c256e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.730 # Must match ci/requirements/*.yml + rev: v0.761 # Must match ci/requirements/*.yml hooks: - id: mypy # run these occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py36-min-all-deps.yml index 3f10a158f91..e7756172311 100644 --- a/ci/requirements/py36-min-all-deps.yml +++ b/ci/requirements/py36-min-all-deps.yml @@ -25,7 +25,7 @@ dependencies: - iris=2.2 - lxml=4.4 # Optional dep of pydap - matplotlib=3.1 - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis=1.2 - netcdf4=1.4 - numba=0.44 diff --git a/ci/requirements/py36.yml b/ci/requirements/py36.yml index 820160b19cc..7450fafbd86 100644 
--- a/ci/requirements/py36.yml +++ b/ci/requirements/py36.yml @@ -21,7 +21,7 @@ dependencies: - iris - lxml # optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba diff --git a/ci/requirements/py37-windows.yml b/ci/requirements/py37-windows.yml index 614a3bb1fab..d9e634c74ae 100644 --- a/ci/requirements/py37-windows.yml +++ b/ci/requirements/py37-windows.yml @@ -21,7 +21,7 @@ dependencies: - iris - lxml # Optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba diff --git a/ci/requirements/py37.yml b/ci/requirements/py37.yml index 4a7aaf7d32b..2f879e29f87 100644 --- a/ci/requirements/py37.yml +++ b/ci/requirements/py37.yml @@ -21,7 +21,7 @@ dependencies: - iris - lxml # Optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 643c1137d6c..34de5edefc5 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -304,7 +304,7 @@ def _as_variables_or_variable(arg): def _unpack_dict_tuples( result_vars: Mapping[Hashable, Tuple[Variable, ...]], num_outputs: int ) -> Tuple[Dict[Hashable, Variable], ...]: - out = tuple({} for _ in range(num_outputs)) # type: ignore + out: Tuple[Dict[Hashable, Variable], ...] = tuple({} for _ in range(num_outputs)) for name, values in result_vars.items(): for value, results_dict in zip(values, out): results_dict[name] = value diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4564dc80a8a..92a64508104 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -897,11 +897,11 @@ def _replace( if dims is not None: self._dims = dims if attrs is not _default: - self._attrs = attrs # type: ignore # FIXME need mypy 0.750 + self._attrs = attrs if indexes is not _default: - self._indexes = indexes # type: ignore # FIXME need mypy 0.750 + self._indexes = indexes if encoding is not _default: - self._encoding = encoding # type: ignore # FIXME need mypy 0.750 + self._encoding = encoding obj = self else: if variables is None: From e646dd987f414e9c1b12b106e17ba650762fa6b6 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 17 Jan 2020 20:42:24 +0000 Subject: [PATCH 20/23] hardcoded xarray.__all__ (#3703) * Fix mypy --strict * isort * mypy version bump * trivial * testing.__all__ and ufuncs.__all__ * isort * Revert mypy version bump * Revert isort --- doc/whats-new.rst | 2 ++ xarray/__init__.py | 53 ++++++++++++++++++++++++++++++++++- xarray/testing.py | 7 +++++ xarray/ufuncs.py | 70 ++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 123 insertions(+), 9 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a60419f89a3..fef1b988f01 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -87,6 +87,8 @@ Bug fixes - Fix a regression in :py:meth:`Dataset.drop`: allow passing any iterable when dropping variables (:issue:`3552`, :pull:`3693`) By `Justus Magin `_. +- Fixed errors emitted by ``mypy --strict`` in modules that import xarray. + (:issue:`3695`) by `Guido Imperiale `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/__init__.py b/xarray/__init__.py index 394dd0f80bc..0b44c293b3f 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,5 +1,4 @@ """ isort:skip_file """ -# flake8: noqa from ._version import get_versions @@ -42,3 +41,55 @@ from . import testing from .core.common import ALL_DIMS + +# A hardcoded __all__ variable is necessary to appease +# `mypy --strict` running in projects that import xarray. +__all__ = ( + # Sub-packages + "ufuncs", + "testing", + "tutorial", + # Top-level functions + "align", + "apply_ufunc", + "as_variable", + "auto_combine", + "broadcast", + "cftime_range", + "combine_by_coords", + "combine_nested", + "concat", + "decode_cf", + "dot", + "full_like", + "load_dataarray", + "load_dataset", + "map_blocks", + "merge", + "ones_like", + "open_dataarray", + "open_dataset", + "open_mfdataset", + "open_rasterio", + "open_zarr", + "register_dataarray_accessor", + "register_dataset_accessor", + "save_mfdataset", + "set_options", + "show_versions", + "where", + "zeros_like", + # Classes + "CFTimeIndex", + "Coordinate", + "DataArray", + "Dataset", + "IndexVariable", + "Variable", + # Exceptions + "MergeError", + "SerializationWarning", + # Constants + "__version__", + "ALL_DIMS", +) diff --git a/xarray/testing.py b/xarray/testing.py index 5c3ca8a3cca..ac189f7e023 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -10,6 +10,13 @@ from xarray.core.indexes import default_indexes from xarray.core.variable import IndexVariable, Variable +__all__ = ( + "assert_allclose", + "assert_chunks_equal", + "assert_equal", + "assert_identical", +) + def _decode_string_data(data): if data.dtype.kind == "S": diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index ae2c5c574b6..8ab2b7cfe31 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -132,14 +132,68 @@ def _create_op(name): return func -__all__ = """logaddexp logaddexp2 conj exp log log2 log10 log1p expm1 sqrt - square sin cos tan arcsin arccos arctan arctan2 hypot sinh cosh - tanh arcsinh arccosh arctanh deg2rad rad2deg logical_and - logical_or logical_xor logical_not maximum minimum fmax fmin - isreal iscomplex isfinite isinf isnan signbit copysign nextafter - ldexp fmod floor ceil trunc degrees radians rint fix angle real - imag fabs sign frexp fmod - """.split() +__all__ = ( # noqa: F822 + "angle", + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctan2", + "arctanh", + "ceil", + "conj", + "copysign", + "cos", + "cosh", + "deg2rad", + "degrees", + "exp", + "expm1", + "fabs", + "fix", + "floor", + "fmax", + "fmin", + "fmod", + "fmod", + "frexp", + "hypot", + "imag", + "iscomplex", + "isfinite", + "isinf", + "isnan", + "isreal", + "ldexp", + "log", + "log10", + "log1p", + "log2", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "nextafter", + "rad2deg", + "radians", + "real", + "rint", + "sign", + "signbit", + "sin", + "sinh", + "sqrt", + "square", + "tan", + "tanh", + "trunc", +) + for name in __all__: globals()[name] = _create_op(name) From c32e58b4fff72816c6b554db51509bea6a891cdc Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 17 Jan 2020 21:00:23 +0000 Subject: [PATCH 21/23] One-off isort run (#3705) * isort run * Fix docs regression --- doc/conf.py | 2 +- properties/test_pandas_roundtrip.py | 16 ++++++++-------- xarray/core/dataset.py | 2 +- xarray/core/duck_array_ops.py | 2 +- xarray/core/parallel.py | 2 +- xarray/tests/test_accessor_dt.py | 3 +-- xarray/tests/test_units.py | 
1 + 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 11abda6bb63..578f9cf550d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -25,7 +25,7 @@ os.environ["PYTHONPATH"] = str(root) sys.path.insert(0, str(root)) -import xarray +import xarray # isort:skip allowed_failures = set() diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index a8005d319d6..5fc097f1f5e 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -1,20 +1,20 @@ """ Property-based tests for roundtripping between xarray and pandas objects. """ -import pytest - -pytest.importorskip("hypothesis") - from functools import partial -import hypothesis.extra.numpy as npst -import hypothesis.extra.pandas as pdst -import hypothesis.strategies as st -from hypothesis import given import numpy as np import pandas as pd +import pytest + import xarray as xr +pytest.importorskip("hypothesis") +import hypothesis.extra.numpy as npst # isort:skip +import hypothesis.extra.pandas as pdst # isort:skip +import hypothesis.strategies as st # isort:skip +from hypothesis import given # isort:skip + numeric_dtypes = st.one_of( npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 92a64508104..82ddc8a535f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -92,8 +92,8 @@ IndexVariable, Variable, as_variable, - broadcast_variables, assert_unique_multiindex_level_names, + broadcast_variables, ) if TYPE_CHECKING: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 98b371ab7c3..3fa97f0b448 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -11,7 +11,7 @@ import numpy as np import pandas as pd -from . import dask_array_ops, dask_array_compat, dtypes, npcompat, nputils +from . 
import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast from .pycompat import dask_array_type diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index e4fb5803191..facfa06b23c 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -13,8 +13,8 @@ from typing import ( Any, Callable, - Dict, DefaultDict, + Dict, Hashable, Mapping, Sequence, diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 67ca12532c7..f178720a6e1 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -11,8 +11,7 @@ requires_cftime, requires_dask, ) - -from .test_dask import raise_if_dask_computes, assert_chunks_equal +from .test_dask import assert_chunks_equal, raise_if_dask_computes class TestDatetimeAccessor: diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 2cb1550c088..6f9128839a7 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -7,6 +7,7 @@ import xarray as xr from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE + from .test_variable import VariableSubclassobjects pint = pytest.importorskip("pint") From 5c97641dd129a9dcc4346a137bcf32a5a982d37a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Mon, 20 Jan 2020 06:46:13 +0100 Subject: [PATCH 22/23] =?UTF-8?q?ENH:=20enable=20`H5NetCDFStore`=20to=20wo?= =?UTF-8?q?rk=20with=20already=20open=20h5netcdf.File=20a=E2=80=A6=20(#361?= =?UTF-8?q?8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ENH: enable `H5NetCDFStore` to work with already open h5netcdf.File and h5netcdf.Group objects, add test * FIX: add `.open` method for file-like objects * FIX: reformat using black * WIP: add item to whats-new.rst * FIX: temporary fix to tackle issue #3680 * FIX: do not use private API, use find_root_and_group instead * FIX: reformat using black --- doc/whats-new.rst | 5 +++ xarray/backends/api.py | 6 ++-- xarray/backends/common.py | 2 +- xarray/backends/h5netcdf_.py | 61 ++++++++++++++++++++++++++--------- xarray/tests/test_backends.py | 23 ++++++++++++- 5 files changed, 77 insertions(+), 20 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fef1b988f01..f9e2e9270b0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,11 @@ Breaking changes New Features ~~~~~~~~~~~~ +- Support using an existing, opened h5netcdf ``File`` with + :py:class:`~xarray.backends.H5NetCDFStore`. This permits creating an + :py:class:`~xarray.Dataset` from a h5netcdf ``File`` that has been opened + using other means (:issue:`3618`). + By `Kai Mühlbauer `_. - Implement :py:func:`median` and :py:func:`nanmedian` for dask arrays. This works by rechunking to a single chunk along all reduction axes. (:issue:`2999`). By `Deepak Cherian `_. 
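In plain terms, the new constructor path looks like the following sketch (file and group names are hypothetical; it mirrors `test_already_open_dataset_group`, added at the end of this patch):

    import h5netcdf
    import xarray as xr
    from xarray import backends

    # open the file by other means first, e.g. directly with h5netcdf
    h5 = h5netcdf.File("data.nc", mode="r")

    # wrap the already-open handle; path-based opening remains available
    # via backends.H5NetCDFStore.open("data.nc", group="g")
    store = backends.H5NetCDFStore(h5, group="g")
    ds = xr.open_dataset(store)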
diff --git a/xarray/backends/api.py b/xarray/backends/api.py index eea1fb9ddce..0990fc46219 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -503,7 +503,7 @@ def maybe_decode_store(store, lock=False): elif engine == "pydap": store = backends.PydapDataStore.open(filename_or_obj, **backend_kwargs) elif engine == "h5netcdf": - store = backends.H5NetCDFStore( + store = backends.H5NetCDFStore.open( filename_or_obj, group=group, lock=lock, **backend_kwargs ) elif engine == "pynio": @@ -527,7 +527,7 @@ def maybe_decode_store(store, lock=False): if engine == "scipy": store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == "h5netcdf": - store = backends.H5NetCDFStore( + store = backends.H5NetCDFStore.open( filename_or_obj, group=group, lock=lock, **backend_kwargs ) @@ -981,7 +981,7 @@ def open_mfdataset( WRITEABLE_STORES: Dict[str, Callable] = { "netcdf4": backends.NetCDF4DataStore.open, "scipy": backends.ScipyDataStore, - "h5netcdf": backends.H5NetCDFStore, + "h5netcdf": backends.H5NetCDFStore.open, } diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 9786e0a0203..fa3ee19f542 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -34,7 +34,7 @@ def find_root_and_group(ds): """Find the root and group name of a netCDF4/h5netcdf dataset.""" hierarchy = () while ds.parent is not None: - hierarchy = (ds.name,) + hierarchy + hierarchy = (ds.name.split("/")[-1],) + hierarchy ds = ds.parent group = "/" + "/".join(hierarchy) return ds, group diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 51ed512f98b..203213afe82 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -4,9 +4,9 @@ from .. import Variable from ..core import indexing -from ..core.utils import FrozenDict -from .common import WritableCFDataStore -from .file_manager import CachingFileManager +from ..core.utils import FrozenDict, is_remote_uri +from .common import WritableCFDataStore, find_root_and_group +from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( BaseNetCDF4Array, @@ -69,8 +69,47 @@ class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ - def __init__( - self, + __slots__ = ( + "autoclose", + "format", + "is_remote", + "lock", + "_filename", + "_group", + "_manager", + "_mode", + ) + + def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False): + + import h5netcdf + + if isinstance(manager, (h5netcdf.File, h5netcdf.Group)): + if group is None: + root, group = find_root_and_group(manager) + else: + if not type(manager) is h5netcdf.File: + raise ValueError( + "must supply a h5netcdf.File if the group " + "argument is provided" + ) + root = manager + manager = DummyFileManager(root) + + self._manager = manager + self._group = group + self._mode = mode + self.format = None + # todo: utilizing find_root_and_group seems a bit clunky + # making filename available on h5netcdf.Group seems better + self._filename = find_root_and_group(self.ds)[0].filename + self.is_remote = is_remote_uri(self._filename) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @classmethod + def open( + cls, filename, mode="r", format=None, @@ -86,22 +125,14 @@ def __init__( kwargs = {"invalid_netcdf": invalid_netcdf} - self._manager = CachingFileManager( - h5netcdf.File, filename, mode=mode, kwargs=kwargs - ) - if lock is None: if mode == "r": 
                lock = HDF5_LOCK
             else:
                 lock = combine_locks([HDF5_LOCK, get_write_lock(filename)])

-        self._group = group
-        self.format = format
-        self._filename = filename
-        self._mode = mode
-        self.lock = ensure_lock(lock)
-        self.autoclose = autoclose
+        manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs)
+        return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose)

     def _acquire(self, needs_lock=True):
         with self._manager.acquire_context(needs_lock) as root:
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 4fccdf2dd6c..0436ae9d244 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -2182,7 +2182,7 @@ class TestH5NetCDFData(NetCDF4Base):
     @contextlib.contextmanager
     def create_store(self):
         with create_tmp_file() as tmp_file:
-            yield backends.H5NetCDFStore(tmp_file, "w")
+            yield backends.H5NetCDFStore.open(tmp_file, "w")

     @pytest.mark.filterwarnings("ignore:complex dtypes are supported by h5py")
     @pytest.mark.parametrize(
@@ -2345,6 +2345,27 @@ def test_dump_encodings_h5py(self):
         assert actual.x.encoding["compression"] == "lzf"
         assert actual.x.encoding["compression_opts"] is None

+    def test_already_open_dataset_group(self):
+        import h5netcdf
+
+        with create_tmp_file() as tmp_file:
+            with nc4.Dataset(tmp_file, mode="w") as nc:
+                group = nc.createGroup("g")
+                v = group.createVariable("x", "int")
+                v[...] = 42
+
+            h5 = h5netcdf.File(tmp_file, mode="r")
+            store = backends.H5NetCDFStore(h5["g"])
+            with open_dataset(store) as ds:
+                expected = Dataset({"x": ((), 42)})
+                assert_identical(expected, ds)
+
+            h5 = h5netcdf.File(tmp_file, mode="r")
+            store = backends.H5NetCDFStore(h5, group="g")
+            with open_dataset(store) as ds:
+                expected = Dataset({"x": ((), 42)})
+                assert_identical(expected, ds)
+

 @requires_h5netcdf
 class TestH5NetCDFFileObject(TestH5NetCDFData):

From aa0f96383062c48cb17f46ef951075c0494c9c0a Mon Sep 17 00:00:00 2001
From: Mathias Hauser
Date: Mon, 20 Jan 2020 13:09:26 +0100
Subject: [PATCH 23/23] Feature/align in dot (#3699)

* add tests

* implement align

* whats new

* fix changes to whats new

* review: fix typos
---
 doc/whats-new.rst                |  5 ++-
 xarray/core/computation.py       |  7 ++++
 xarray/tests/test_computation.py | 54 +++++++++++++++++++++++++++++++
 xarray/tests/test_dataarray.py   | 55 ++++++++++++++++++++++++++++++++
 4 files changed, 120 insertions(+), 1 deletion(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index f9e2e9270b0..25133df8a29 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -21,12 +21,15 @@ v0.15.0 (unreleased)

 Breaking changes
 ~~~~~~~~~~~~~~~~
-
 - Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which
   have been deprecated since 0.12. (:pull:`3650`).
   Instead, specify the encoding when writing to disk or set
   the ``encoding`` attribute directly.
   By `Maximilian Roos `_
+- :py:func:`xarray.dot`, :py:meth:`DataArray.dot`, and the ``@`` operator now
+  use ``align="inner"`` (except when ``xarray.set_options(arithmetic_join="exact")``;
+  :issue:`3694`) by `Mathias Hauser `_.
+

 New Features
 ~~~~~~~~~~~~
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 34de5edefc5..eb9ca8c17fc 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -26,6 +26,7 @@
 from . import duck_array_ops, utils
 from .alignment import deep_align
 from .merge import merge_coordinates_without_align
+from .options import OPTIONS
 from .pycompat import dask_array_type
 from .utils import is_dict_like
 from .variable import Variable
@@ -1175,6 +1176,11 @@ def dot(*arrays, dims=None, **kwargs):
     subscripts = ",".join(subscripts_list)
     subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]])

+    join = OPTIONS["arithmetic_join"]
+    # using "inner" emulates `(a * b).sum()` for all joins (except "exact")
+    if join != "exact":
+        join = "inner"
+
     # subscripts should be passed to np.einsum as arg, not as kwargs. We need
     # to construct a partial function for apply_ufunc to work.
     func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs)
@@ -1183,6 +1189,7 @@ def dot(*arrays, dims=None, **kwargs):
         *arrays,
         input_core_dims=input_core_dims,
         output_core_dims=output_core_dims,
+        join=join,
         dask="allowed",
     )
     return result.transpose(*[d for d in all_dims if d in result.dims])
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index 1f2634cc9b0..2d373d12095 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1043,6 +1043,60 @@ def test_dot(use_dask):
     pickle.loads(pickle.dumps(xr.dot(da_a)))


+@pytest.mark.parametrize("use_dask", [True, False])
+def test_dot_align_coords(use_dask):
+    # GH 3694
+
+    if use_dask:
+        if not has_dask:
+            pytest.skip("test for dask.")
+
+    a = np.arange(30 * 4).reshape(30, 4)
+    b = np.arange(30 * 4 * 5).reshape(30, 4, 5)
+
+    # use partially overlapping coords
+    coords_a = {"a": np.arange(30), "b": np.arange(4)}
+    coords_b = {"a": np.arange(5, 35), "b": np.arange(1, 5)}
+
+    da_a = xr.DataArray(a, dims=["a", "b"], coords=coords_a)
+    da_b = xr.DataArray(b, dims=["a", "b", "c"], coords=coords_b)
+
+    if use_dask:
+        da_a = da_a.chunk({"a": 3})
+        da_b = da_b.chunk({"a": 3})
+
+    # join="inner" is the default
+    actual = xr.dot(da_a, da_b)
+    # `dot` sums over the common dimensions of the arguments
+    expected = (da_a * da_b).sum(["a", "b"])
+    xr.testing.assert_allclose(expected, actual)
+
+    actual = xr.dot(da_a, da_b, dims=...)
+    expected = (da_a * da_b).sum()
+    xr.testing.assert_allclose(expected, actual)
+
+    with xr.set_options(arithmetic_join="exact"):
+        with raises_regex(ValueError, "indexes along dimension"):
+            xr.dot(da_a, da_b)
+
+    # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all
+    # join methods (except "exact")
+    with xr.set_options(arithmetic_join="left"):
+        actual = xr.dot(da_a, da_b)
+        expected = (da_a * da_b).sum(["a", "b"])
+        xr.testing.assert_allclose(expected, actual)
+
+    with xr.set_options(arithmetic_join="right"):
+        actual = xr.dot(da_a, da_b)
+        expected = (da_a * da_b).sum(["a", "b"])
+        xr.testing.assert_allclose(expected, actual)
+
+    with xr.set_options(arithmetic_join="outer"):
+        actual = xr.dot(da_a, da_b)
+        expected = (da_a * da_b).sum(["a", "b"])
+        xr.testing.assert_allclose(expected, actual)
+
+
 def test_where():
     cond = xr.DataArray([True, False], dims="x")
     actual = xr.where(cond, 1, 0)
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 786eb5007a6..962be7548bc 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -3973,6 +3973,43 @@ def test_dot(self):
         with pytest.raises(TypeError):
             da.dot(dm.values)

+    def test_dot_align_coords(self):
+        # GH 3694
+
+        x = np.linspace(-3, 3, 6)
+        y = np.linspace(-3, 3, 5)
+        z_a = range(4)
+        da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
+        da = DataArray(da_vals, coords=[x, y, z_a], dims=["x", "y", "z"])
+
+        z_m = range(2, 6)
+        dm_vals = range(4)
+        dm = DataArray(dm_vals, coords=[z_m], dims=["z"])
+
+        with xr.set_options(arithmetic_join="exact"):
+            with raises_regex(ValueError, "indexes along dimension"):
+                da.dot(dm)
+
+        da_aligned, dm_aligned = xr.align(da, dm, join="inner")
+
+        # nd dot 1d
+        actual = da.dot(dm)
+        expected_vals = np.tensordot(da_aligned.values, dm_aligned.values, [2, 0])
+        expected = DataArray(expected_vals, coords=[x, da_aligned.y], dims=["x", "y"])
+        assert_equal(expected, actual)
+
+        # multiple shared dims
+        dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4))
+        j = np.linspace(-3, 3, 20)
+        dm = DataArray(dm_vals, coords=[j, y, z_m], dims=["j", "y", "z"])
+        da_aligned, dm_aligned = xr.align(da, dm, join="inner")
+        actual = da.dot(dm)
+        expected_vals = np.tensordot(
+            da_aligned.values, dm_aligned.values, axes=([1, 2], [1, 2])
+        )
+        expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"])
+        assert_equal(expected, actual)
+
     def test_matmul(self):

         # copied from above (could make a fixture)
@@ -3986,6 +4023,24 @@ def test_matmul(self):
         expected = da.dot(da)
         assert_identical(result, expected)

+    def test_matmul_align_coords(self):
+        # GH 3694
+
+        x_a = np.arange(6)
+        x_b = np.arange(2, 8)
+        da_vals = np.arange(6)
+        da_a = DataArray(da_vals, coords=[x_a], dims=["x"])
+        da_b = DataArray(da_vals, coords=[x_b], dims=["x"])
+
+        # only test arithmetic_join="inner" (=default)
+        result = da_a @ da_b
+        expected = da_a.dot(da_b)
+        assert_identical(result, expected)
+
+        with xr.set_options(arithmetic_join="exact"):
+            with raises_regex(ValueError, "indexes along dimension"):
+                da_a @ da_b
+
     def test_binary_op_propagate_indexes(self):
         # regression test for GH2227
         self.dv["x"] = np.arange(self.dv.sizes["x"])
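
Taken together, the ``dot`` changes above mean that operands whose indexes only
partially overlap are aligned with an inner join before the contraction, so
``xr.dot(a, b)`` gives the same result as ``(a * b).sum()`` over the shared
dimensions. A minimal sketch of that behavior (the arrays and values below are
illustrative, not taken from the patches):

    # Sketch of the new xr.dot alignment; arrays and values are illustrative.
    import numpy as np
    import xarray as xr

    da_a = xr.DataArray(np.arange(4.0), coords=[("x", [0, 1, 2, 3])])
    da_b = xr.DataArray(np.arange(4.0), coords=[("x", [2, 3, 4, 5])])

    # Only x=2 and x=3 are shared; dot aligns with an inner join, then
    # contracts over the shared dimension: 2.0 * 0.0 + 3.0 * 1.0 == 3.0
    assert float(xr.dot(da_a, da_b)) == 3.0
    assert float((da_a * da_b).sum()) == 3.0

    # arithmetic_join="exact" still refuses to align mismatched indexes.
    with xr.set_options(arithmetic_join="exact"):
        try:
            xr.dot(da_a, da_b)
        except ValueError:
            pass  # "indexes along dimension 'x' are not equal"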
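
Similarly, the ``H5NetCDFStore`` rework earlier in the series lets an
already-open ``h5netcdf.File`` (or one of its groups) be handed straight to
xarray, as the new ``test_already_open_dataset_group`` test exercises. A short
sketch, assuming a file ``example.nc`` containing a group ``"g"`` already
exists (both names are illustrative):

    # Sketch; "example.nc" and its group "g" are assumed to exist.
    import h5netcdf
    import xarray as xr
    from xarray.backends import H5NetCDFStore

    f = h5netcdf.File("example.nc", mode="r")  # caller-managed handle
    store = H5NetCDFStore(f, group="g")        # or: H5NetCDFStore(f["g"])
    ds = xr.open_dataset(store)                # reuses the open handle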