From 8931c695c606927e8f80dea26f5c851ef926810e Mon Sep 17 00:00:00 2001 From: Hasan Ahmad <32473508+HasanAhmadQ7@users.noreply.github.com> Date: Mon, 5 Aug 2019 02:41:32 +0900 Subject: [PATCH 01/10] =?UTF-8?q?BUG:=20fix=20+=20test=20open=5Fmfdataset?= =?UTF-8?q?=20fails=20on=20variable=20attributes=20with=20list=E2=80=A6=20?= =?UTF-8?q?(#3181)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * BUG: fix + test open_mfdataset fails on variable attributes with list type Using open_mfdataset on a series of netcdf files having variable attributes with type list will fail with the following exception, when these attributes have different values from one file to another: solves: #3034 * DOC: add #3034 bug fix to whats new * Update xarray/core/utils.py Co-Authored-By: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * Update xarray/core/utils.py Co-Authored-By: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * BUG: fix + test open_mfdataset fails on variable attributes with list type Using open_mfdataset on a series of netcdf files having variable attributes with type list will fail with the following exception, when these attributes have different values from one file to another: solves: #3034 * DOC: add #3034 bug fix to whats new * Update xarray/core/utils.py Co-Authored-By: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * Update xarray/core/utils.py Co-Authored-By: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * DOC: update whats' new --- doc/whats-new.rst | 10 ++++++---- xarray/core/utils.py | 15 ++++++++++++++- xarray/tests/test_backends.py | 24 ++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1789fcea118..7de6794d5b4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -61,8 +61,12 @@ Bug fixes By `Tom Nicholas `_. - Fixed crash when applying ``distributed.Client.compute()`` to a DataArray (:issue:`3171`). By `Guido Imperiale `_. - - +- Better error message when using groupby on an empty DataArray (:issue:`3037`). + By `Hasan Ahmad `_. +- Fix error that arises when using open_mfdataset on a series of netcdf files + having differing values for a variable attribute of type list. (:issue:`3034`) + By `Hasan Ahmad `_. + .. _whats-new.0.12.3: v0.12.3 (10 July 2019) @@ -103,8 +107,6 @@ Bug fixes - Fix HDF5 error that could arise when reading multiple groups from a file at once (:issue:`2954`). By `Stephan Hoyer `_. -- Better error message when using groupby on an empty DataArray (:issue:`3037`). - By `Hasan Ahmad `_. .. _whats-new.0.12.2: diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 08c911f5e4a..60e0fe1e7d7 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -129,18 +129,31 @@ def maybe_wrap_array(original, new_array): def equivalent(first: T, second: T) -> bool: """Compare two objects for equivalence (identity or equality), using - array_equiv if either object is an ndarray + array_equiv if either object is an ndarray. If both objects are lists, + equivalent is sequentially called on all the elements. """ # TODO: refactor to avoid circular import from . 
import duck_array_ops if isinstance(first, np.ndarray) or isinstance(second, np.ndarray): return duck_array_ops.array_equiv(first, second) + elif isinstance(first, list) or isinstance(second, list): + return list_equiv(first, second) else: return ((first is second) or (first == second) or (pd.isnull(first) and pd.isnull(second))) +def list_equiv(first, second): + equiv = True + if len(first) != len(second): + return False + else: + for f, s in zip(first, second): + equiv = equiv and equivalent(f, s) + return equiv + + def peek_at(iterable: Iterable[T]) -> Tuple[T, Iterator[T]]: """Returns the first value from iterable, as well as a new iterator with the same content as the original iterable diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 539703962c7..026ae6a55ff 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2361,6 +2361,30 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks, assert_identical(original, actual) +@requires_netCDF4 +def test_open_mfdataset_list_attr(): + """ + Case when an attribute of type list differs across the multiple files + """ + from netCDF4 import Dataset + with create_tmp_files(2) as nfiles: + for i in range(2): + f = Dataset(nfiles[i], "w") + f.createDimension("x", 3) + vlvar = f.createVariable("test_var", np.int32, ("x")) + # here create an attribute as a list + vlvar.test_attr = ["string a {}".format(i), + "string b {}".format(i)] + vlvar[:] = np.arange(3) + f.close() + ds1 = open_dataset(nfiles[0]) + ds2 = open_dataset(nfiles[1]) + original = xr.concat([ds1, ds2], dim='x') + with xr.open_mfdataset([nfiles[0], nfiles[1]], combine='nested', + concat_dim='x') as actual: + assert_identical(actual, original) + + @requires_scipy_or_netCDF4 @requires_dask class TestOpenMFDatasetWithDataVarsAndCoordsKw: From 298d532f876193ab76ff1f820b6e2b512becf92d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 4 Aug 2019 23:00:58 +0000 Subject: [PATCH 02/10] Call darray.compute() in plot() (#3183) * Call darray.compute() in plot() * review. --- xarray/plot/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 9d0bf671dda..d0003b702df 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -139,7 +139,7 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, Additional keyword arguments to matplotlib """ - darray = darray.squeeze() + darray = darray.squeeze().compute() plot_dims = set(darray.dims) plot_dims.discard(row) From 50f897058e361ba4d9f2ebf048a31b18a7521c2b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 4 Aug 2019 20:29:19 -0700 Subject: [PATCH 03/10] Internal clean-up of isnull() to avoid relying on pandas (#3132) * Internal clean-up of isnull() to avoid relying on pandas This version should be much more compatible out of the box with duck typing. * Use isnat ufunc * update comment --- properties/test_encode_decode.py | 2 -- xarray/core/duck_array_ops.py | 52 +++++++++++++++++++++++------ xarray/tests/test_duck_array_ops.py | 25 +++++++++++--- 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 13f63f259cf..4b9aa8928b4 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -4,8 +4,6 @@ These ones pass, just as you'd hope! 
""" -from __future__ import absolute_import, division, print_function - import hypothesis.extra.numpy as npst import hypothesis.strategies as st from hypothesis import given, settings diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index c4db95cfd4e..ac204df568f 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -63,19 +63,51 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): around = _dask_or_eager_func('around') isclose = _dask_or_eager_func('isclose') -notnull = _dask_or_eager_func('notnull', eager_module=pd) -_isnull = _dask_or_eager_func('isnull', eager_module=pd) + +if hasattr(np, 'isnat') and ( + dask_array is None or hasattr(dask_array_type, '__array_ufunc__')): + # np.isnat is available since NumPy 1.13, so __array_ufunc__ is always + # supported. + isnat = np.isnat +else: + isnat = _dask_or_eager_func('isnull', eager_module=pd) +isnan = _dask_or_eager_func('isnan') +zeros_like = _dask_or_eager_func('zeros_like') + + +pandas_isnull = _dask_or_eager_func('isnull', eager_module=pd) def isnull(data): - # GH837, GH861 - # isnull fcn from pandas will throw TypeError when run on numpy structured - # array therefore for dims that are np structured arrays we assume all - # data is present - try: - return _isnull(data) - except TypeError: - return np.zeros(data.shape, dtype=bool) + data = asarray(data) + scalar_type = data.dtype.type + if issubclass(scalar_type, (np.datetime64, np.timedelta64)): + # datetime types use NaT for null + # note: must check timedelta64 before integers, because currently + # timedelta64 inherits from np.integer + return isnat(data) + elif issubclass(scalar_type, np.inexact): + # float types use NaN for null + return isnan(data) + elif issubclass( + scalar_type, (np.bool_, np.integer, np.character, np.void) + ): + # these types cannot represent missing values + return zeros_like(data, dtype=bool) + else: + # at this point, array should have dtype=object + if isinstance(data, (np.ndarray, dask_array_type)): + return pandas_isnull(data) + else: + # Not reachable yet, but intended for use with other duck array + # types. For full consistency with pandas, we should accept None as + # a null value as well as NaN, but it isn't clear how to do this + # with duck typing. 
+ return data != data + + +def notnull(data): + return ~isnull(data) transpose = _dask_or_eager_func('transpose') diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index b0b7cfbd943..bb22a15e227 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -178,14 +178,18 @@ def test_wrong_shape(self): assert not array_notnull_equiv(a, b) @pytest.mark.parametrize("val1, val2, val3, null", [ - (1, 2, 3, None), + (np.datetime64('2000'), + np.datetime64('2001'), + np.datetime64('2002'), + np.datetime64('NaT')), (1., 2., 3., np.nan), - (1., 2., 3., None), ('foo', 'bar', 'baz', None), + ('foo', 'bar', 'baz', np.nan), ]) def test_types(self, val1, val2, val3, null): - arr1 = np.array([val1, null, val3, null]) - arr2 = np.array([val1, val2, null, null]) + dtype = object if isinstance(val1, str) else None + arr1 = np.array([val1, null, val3, null], dtype=dtype) + arr2 = np.array([val1, val2, null, null], dtype=dtype) assert array_notnull_equiv(arr1, arr2) @@ -432,6 +436,19 @@ def test_argmin_max_error(): da.argmin(dim='y') +@pytest.mark.parametrize('array', [ + np.array([np.datetime64('2000-01-01'), np.datetime64('NaT')]), + np.array([np.timedelta64(1, 'h'), np.timedelta64('NaT')]), + np.array([0.0, np.nan]), + np.array([1j, np.nan]), + np.array(['foo', np.nan], dtype=object), +]) +def test_isnull(array): + expected = np.array([False, True]) + actual = duck_array_ops.isnull(array) + np.testing.assert_equal(expected, actual) + + @requires_dask def test_isnull_with_dask(): da = construct_dataarray(2, np.float32, contains_nan=True, dask=True) From bbd25c65a874dfd505ab9df43f6109f6343e5f4d Mon Sep 17 00:00:00 2001 From: Nezar Abdennur Date: Mon, 5 Aug 2019 14:44:43 -0400 Subject: [PATCH 04/10] Support for __array_function__ implementers (sparse arrays) [WIP] (#3117) * Support for __array_function__ implementers * Pep8 * Consistent naming * Check for NEP18 enabled and nep18 non-numpy arrays * Replace .values with .data * Add initial test for nep18 * Fix linting issues * Add parameterized tests * Internal clean-up of isnull() to avoid relying on pandas This version should be much more compatible out of the box with duck typing. 
* Add sparse to ci requirements * Moar tests * Two more patches for __array_function__ duck-arrays * Don't use coords attribute from duck-arrays that aren't derived from DataWithCoords * Improve checking for coords, and autopep8 * Skip tests if NEP-18 envvar is not set * flake8 * Update xarray/core/dataarray.py Co-Authored-By: Stephan Hoyer * Fix coords parsing * More tests * Add align tests * Replace nep18 tests with more extensive tests on pydata/sparse * Add xfails for missing np.result_type (fixed by pydata/sparse/pull/261) * Fix xpasses * Revert isnull/notnull * Fix as_like_arrays by coercing dense arrays to COO if any sparse * Make Variable.load a no-op for non-dask duck arrays * Add additional method tests * Fix utils.as_scalar to handle duck arrays with ndim>0 --- ci/requirements/py37.yml | 1 + xarray/core/dataarray.py | 7 +- xarray/core/duck_array_ops.py | 12 +- xarray/core/formatting.py | 7 +- xarray/core/indexing.py | 13 + xarray/core/npcompat.py | 15 + xarray/core/pycompat.py | 7 + xarray/core/utils.py | 4 +- xarray/core/variable.py | 39 +- xarray/tests/test_sparse.py | 689 ++++++++++++++++++++++++++++++++++ 10 files changed, 775 insertions(+), 19 deletions(-) create mode 100644 xarray/tests/test_sparse.py diff --git a/ci/requirements/py37.yml b/ci/requirements/py37.yml index f1b6d46fd95..f61aab69e0f 100644 --- a/ci/requirements/py37.yml +++ b/ci/requirements/py37.yml @@ -21,6 +21,7 @@ dependencies: - pip - scipy - seaborn + - sparse - toolz - rasterio - boto3 diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0e28613323e..ba6477f34cc 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -260,8 +260,10 @@ def __init__( else: # try to fill in arguments from data if they weren't supplied if coords is None: - coords = getattr(data, 'coords', None) - if isinstance(data, pd.Series): + + if isinstance(data, DataArray): + coords = data.coords + elif isinstance(data, pd.Series): coords = [data.index] elif isinstance(data, pd.DataFrame): coords = [data.index, data.columns] @@ -269,6 +271,7 @@ def __init__( coords = [data] elif isinstance(data, pdcompat.Panel): coords = [data.items, data.major_axis, data.minor_axis] + if dims is None: dims = getattr(data, 'dims', getattr(coords, 'dims', None)) if name is None: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index ac204df568f..f78ecb969a1 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -13,7 +13,7 @@ from . 
import dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast -from .pycompat import dask_array_type +from .pycompat import dask_array_type, sparse_array_type try: import dask.array as dask_array @@ -64,6 +64,7 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): around = _dask_or_eager_func('around') isclose = _dask_or_eager_func('isclose') + if hasattr(np, 'isnat') and ( dask_array is None or hasattr(dask_array_type, '__array_ufunc__')): # np.isnat is available since NumPy 1.13, so __array_ufunc__ is always @@ -153,7 +154,11 @@ def trapz(y, x, axis): def asarray(data): - return data if isinstance(data, dask_array_type) else np.asarray(data) + return ( + data if (isinstance(data, dask_array_type) + or hasattr(data, '__array_function__')) + else np.asarray(data) + ) def as_shared_dtype(scalars_or_arrays): @@ -170,6 +175,9 @@ def as_shared_dtype(scalars_or_arrays): def as_like_arrays(*data): if all(isinstance(d, dask_array_type) for d in data): return data + elif any(isinstance(d, sparse_array_type) for d in data): + from sparse import COO + return tuple(COO(d) for d in data) else: return tuple(np.asarray(d) for d in data) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 9f39acde90b..00c813ece09 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -357,7 +357,10 @@ def set_numpy_options(*args, **kwargs): def short_array_repr(array): - array = np.asarray(array) + + if not hasattr(array, '__array_function__'): + array = np.asarray(array) + # default to lower precision so a full (abbreviated) line can fit on # one line with the default display_width options = { @@ -394,7 +397,7 @@ def short_data_repr(array): if isinstance(getattr(array, 'variable', array)._data, dask_array_type): return short_dask_repr(array) elif array._in_memory or array.size < 1e5: - return short_array_repr(array.values) + return short_array_repr(array.data) else: return u'[{} values with dtype={}]'.format(array.size, array.dtype) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 14f62c533da..aea5a5a3f4f 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -657,6 +657,9 @@ def as_indexable(array): return PandasIndexAdapter(array) if isinstance(array, dask_array_type): return DaskIndexingAdapter(array) + if hasattr(array, '__array_function__'): + return NdArrayLikeIndexingAdapter(array) + raise TypeError('Invalid array type: {}'.format(type(array))) @@ -1189,6 +1192,16 @@ def __setitem__(self, key, value): raise +class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter): + def __init__(self, array): + if not hasattr(array, '__array_function__'): + raise TypeError( + 'NdArrayLikeIndexingAdapter must wrap an object that ' + 'implements the __array_function__ protocol' + ) + self.array = array + + class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a dask array to support explicit indexing.""" diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 856cfc4fe79..afef9a5e083 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -357,3 +357,18 @@ def moveaxis(a, source, destination): # https://github.com/numpy/numpy/issues/7370 # https://github.com/numpy/numpy-stubs/ DTypeLike = Union[np.dtype, str] + + +# from dask/array/utils.py +def _is_nep18_active(): + class A: + def __array_function__(self, *args, **kwargs): + return True + + try: + return np.concatenate([A()]) + except ValueError: + return False + + +IS_NEP18_ACTIVE = _is_nep18_active() diff --git 
a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 16c5325565c..259f44f2862 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -8,3 +8,10 @@ dask_array_type = (dask.array.Array,) except ImportError: # pragma: no cover dask_array_type = () + +try: + # solely for isinstance checks + import sparse + sparse_array_type = (sparse.SparseArray,) +except ImportError: # pragma: no cover + sparse_array_type = () diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 60e0fe1e7d7..b3e19aebcbf 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -243,7 +243,9 @@ def is_scalar(value: Any) -> bool: return ( getattr(value, 'ndim', None) == 0 or isinstance(value, (str, bytes)) or not - isinstance(value, (Iterable, ) + dask_array_type)) + (isinstance(value, (Iterable, ) + dask_array_type) or + hasattr(value, '__array_function__')) + ) def is_valid_numpy_dtype(dtype: Any) -> bool: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 5c6a3ad0f30..3c9d85f13d7 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -17,6 +17,7 @@ as_indexable) from .options import _get_keep_attrs from .pycompat import dask_array_type, integer_types +from .npcompat import IS_NEP18_ACTIVE from .utils import ( OrderedSet, decode_numpy_dict_values, either_dict_or_kwargs, ensure_us_time_resolution) @@ -179,6 +180,18 @@ def as_compatible_data(data, fastpath=False): else: data = np.asarray(data) + if not isinstance(data, np.ndarray): + if hasattr(data, '__array_function__'): + if IS_NEP18_ACTIVE: + return data + else: + raise TypeError( + 'Got an NumPy-like array type providing the ' + '__array_function__ protocol but NEP18 is not enabled. ' + 'Check that numpy >= v1.16 and that the environment ' + 'variable "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION" is set to ' + '"1"') + # validate whether the data is valid data types data = np.asarray(data) @@ -288,7 +301,7 @@ def _in_memory(self): @property def data(self): - if isinstance(self._data, dask_array_type): + if hasattr(self._data, '__array_function__'): return self._data else: return self.values @@ -320,7 +333,7 @@ def load(self, **kwargs): """ if isinstance(self._data, dask_array_type): self._data = as_compatible_data(self._data.compute(**kwargs)) - elif not isinstance(self._data, np.ndarray): + elif not hasattr(self._data, '__array_function__'): self._data = np.asarray(self._data) return self @@ -705,8 +718,8 @@ def __setitem__(self, key, value): if new_order: value = duck_array_ops.asarray(value) - value = value[(len(dims) - value.ndim) * (np.newaxis,) + - (Ellipsis,)] + value = value[(len(dims) - value.ndim) * (np.newaxis,) + + (Ellipsis,)] value = duck_array_ops.moveaxis( value, new_order, range(len(new_order))) @@ -805,7 +818,8 @@ def copy(self, deep=True, data=None): data = indexing.MemoryCachedArray(data.array) if deep: - if isinstance(data, dask_array_type): + if (hasattr(data, '__array_function__') + or isinstance(data, dask_array_type)): data = data.copy() elif not isinstance(data, PandasIndexAdapter): # pandas.Index is immutable @@ -1494,9 +1508,10 @@ def equals(self, other, equiv=duck_array_ops.array_equiv): """ other = getattr(other, 'variable', other) try: - return (self.dims == other.dims and - (self._data is other._data or - equiv(self.data, other.data))) + return ( + self.dims == other.dims and + (self._data is other._data or equiv(self.data, other.data)) + ) except (TypeError, AttributeError): return False @@ -1517,8 +1532,8 @@ def identical(self, other): """Like equals, but also checks attributes. 
""" try: - return (utils.dict_equiv(self.attrs, other.attrs) and - self.equals(other)) + return (utils.dict_equiv(self.attrs, other.attrs) + and self.equals(other)) except (TypeError, AttributeError): return False @@ -1959,8 +1974,8 @@ def equals(self, other, equiv=None): # otherwise use the native index equals, rather than looking at _data other = getattr(other, 'variable', other) try: - return (self.dims == other.dims and - self._data_equals(other)) + return (self.dims == other.dims + and self._data_equals(other)) except (TypeError, AttributeError): return False diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py new file mode 100644 index 00000000000..3aa407f72bc --- /dev/null +++ b/xarray/tests/test_sparse.py @@ -0,0 +1,689 @@ +from collections import OrderedDict +from contextlib import suppress +from distutils.version import LooseVersion +from textwrap import dedent +import pickle +import numpy as np +import pandas as pd + +from xarray import DataArray, Dataset, Variable +from xarray.tests import mock +from xarray.core.npcompat import IS_NEP18_ACTIVE +import xarray as xr +import xarray.ufuncs as xu + +from . import ( + assert_allclose, assert_array_equal, assert_equal, assert_frame_equal, + assert_identical, raises_regex) + +import pytest + +param = pytest.param +xfail = pytest.mark.xfail + +if not IS_NEP18_ACTIVE: + pytest.skip("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled", + allow_module_level=True) + +sparse = pytest.importorskip('sparse') +from sparse.utils import assert_eq as assert_sparse_eq # noqa +from sparse import COO, SparseArray # noqa + + +def make_ndarray(shape): + return np.arange(np.prod(shape)).reshape(shape) + + +def make_sparray(shape): + return sparse.random(shape, density=0.1, random_state=0) + + +def make_xrvar(dim_lengths): + return xr.Variable( + tuple(dim_lengths.keys()), + make_sparray(shape=tuple(dim_lengths.values()))) + + +def make_xrarray(dim_lengths, coords=None, name='test'): + if coords is None: + coords = {d: np.arange(n) for d, n in dim_lengths.items()} + return xr.DataArray( + make_sparray(shape=tuple(dim_lengths.values())), + dims=tuple(coords.keys()), + coords=coords, + name=name) + + +class do: + def __init__(self, meth, *args, **kwargs): + self.meth = meth + self.args = args + self.kwargs = kwargs + + def __call__(self, obj): + return getattr(obj, self.meth)(*self.args, **self.kwargs) + + def __repr__(self): + return 'obj.{}(*{}, **{})'.format(self.meth, self.args, self.kwargs) + + +@pytest.mark.parametrize("prop", [ + 'chunks', + 'data', + 'dims', + 'dtype', + 'encoding', + 'imag', + 'nbytes', + 'ndim', + param('values', marks=xfail(reason='Coercion to dense')) +]) +def test_variable_property(prop): + var = make_xrvar({'x': 10, 'y': 5}) + getattr(var, prop) + + +@pytest.mark.parametrize("func,sparse_output", [ + (do('all'), False), + (do('any'), False), + (do('astype', dtype=int), True), + (do('broadcast_equals', make_xrvar({'x': 10, 'y': 5})), False), + (do('clip', min=0, max=1), True), + (do('coarsen', windows={'x': 2}, func=np.sum), True), + (do('compute'), True), + (do('conj'), True), + (do('copy'), True), + (do('count'), False), + (do('equals', make_xrvar({'x': 10, 'y': 5})), False), + (do('get_axis_num', dim='x'), False), + (do('identical', other=make_xrvar({'x': 10, 'y': 5})), False), + (do('isel', x=slice(2, 4)), True), + (do('isnull'), True), + (do('load'), True), + (do('mean'), False), + (do('notnull'), True), + (do('roll'), True), + (do('round'), True), + (do('set_dims', dims=('x', 'y', 'z')), True), + 
(do('stack', dimensions={'flat': ('x', 'y')}), True), + (do('to_base_variable'), True), + (do('transpose'), True), + (do('unstack', dimensions={'x': {'x1': 5, 'x2': 2}}), True), + + param(do('argmax'), True, + marks=xfail(reason='Missing implementation for np.argmin')), + param(do('argmin'), True, + marks=xfail(reason='Missing implementation for np.argmax')), + param(do('argsort'), True, + marks=xfail(reason="'COO' object has no attribute 'argsort'")), + param(do('chunk', chunks=(5, 5)), True, + marks=xfail), + param(do('concat', variables=[make_xrvar({'x': 10, 'y': 5}), + make_xrvar({'x': 10, 'y': 5})]), True, + marks=xfail(reason='Coercion to dense')), + param(do('conjugate'), True, + marks=xfail(reason="'COO' object has no attribute 'conjugate'")), + param(do('cumprod'), True, + marks=xfail(reason='Missing implementation for np.nancumprod')), + param(do('cumsum'), True, + marks=xfail(reason='Missing implementation for np.nancumsum')), + param(do('fillna', 0), True, + marks=xfail(reason='Missing implementation for np.result_type')), + param(do('item', (1, 1)), False, + marks=xfail(reason="'COO' object has no attribute 'item'")), + param(do('max'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('median'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('min'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('no_conflicts', other=make_xrvar({'x': 10, 'y': 5})), True, + marks=xfail(reason='mixed sparse-dense operation')), + param(do('pad_with_fill_value', pad_widths={'x': (1, 1)}, fill_value=5), True, # noqa + marks=xfail(reason='Missing implementation for np.pad')), + param(do('prod'), False, + marks=xfail(reason='Missing implementation for np.result_type')), + param(do('quantile', q=0.5), True, + marks=xfail(reason='Missing implementation for np.nanpercentile')), + param(do('rank', dim='x'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('reduce', func=np.sum, dim='x'), True, + marks=xfail(reason='Coercion to dense')), + param(do('rolling_window', dim='x', window=2, window_dim='x_win'), True, + marks=xfail(reason='Missing implementation for np.pad')), + param(do('shift', x=2), True, + marks=xfail(reason='mixed sparse-dense operation')), + param(do('std'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('sum'), False, + marks=xfail(reason='Missing implementation for np.result_type')), + param(do('var'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('to_dict'), False, + marks=xfail(reason='Coercion to dense')), + param(do('where', cond=make_xrvar({'x': 10, 'y': 5}) > 0.5), True, + marks=xfail(reason='Coercion of dense to sparse when using sparse mask')), # noqa +], +ids=repr) +def test_variable_method(func, sparse_output): + var_s = make_xrvar({'x': 10, 'y': 5}) + var_d = xr.Variable(var_s.dims, var_s.data.todense()) + ret_s = func(var_s) + ret_d = func(var_d) + + if sparse_output: + assert isinstance(ret_s.data, SparseArray) + assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) + else: + assert np.allclose(ret_s, ret_d, equal_nan=True) + + +@pytest.mark.parametrize("func,sparse_output", [ + (do('squeeze'), True), + + param(do('to_index'), False, + marks=xfail(reason='Coercion to dense')), + param(do('to_index_variable'), False, + marks=xfail(reason='Coercion to dense')), + param(do('searchsorted', 0.5), True, + marks=xfail(reason="'COO' object has no attribute 'searchsorted'")), +]) +def 
test_1d_variable_method(func, sparse_output): + var_s = make_xrvar({'x': 10}) + var_d = xr.Variable(var_s.dims, var_s.data.todense()) + ret_s = func(var_s) + ret_d = func(var_d) + + if sparse_output: + assert isinstance(ret_s.data, SparseArray) + assert np.allclose(ret_s.data.todense(), ret_d.data) + else: + assert np.allclose(ret_s, ret_d) + + +class TestSparseVariable: + @pytest.fixture(autouse=True) + def setUp(self): + self.data = sparse.random((4, 6), random_state=0, density=0.5) + self.var = xr.Variable(('x', 'y'), self.data) + + def test_unary_op(self): + assert_sparse_eq(-self.var.data, -self.data) + assert_sparse_eq(abs(self.var).data, abs(self.data)) + assert_sparse_eq(self.var.round().data, self.data.round()) + + def test_univariate_ufunc(self): + assert_sparse_eq(np.sin(self.data), xu.sin(self.var).data) + + def test_bivariate_ufunc(self): + assert_sparse_eq(np.maximum(self.data, 0), + xu.maximum(self.var, 0).data) + assert_sparse_eq(np.maximum(self.data, 0), + xu.maximum(0, self.var).data) + + def test_repr(self): + expected = dedent("""\ + + """) + assert expected == repr(self.var) + + def test_pickle(self): + v1 = self.var + v2 = pickle.loads(pickle.dumps(v1)) + assert_sparse_eq(v1.data, v2.data) + + @pytest.mark.xfail(reason="Missing implementation for np.result_type") + def test_missing_values(self): + a = np.array([0, 1, np.nan, 3]) + s = COO.from_numpy(a) + var_s = Variable('x', s) + assert np.all(var_s.fillna(2).data.todense() == np.arange(4)) + assert np.all(var_s.count() == 3) + + +@pytest.mark.parametrize("prop", [ + 'attrs', + 'chunks', + 'coords', + 'data', + 'dims', + 'dtype', + 'encoding', + 'imag', + 'indexes', + 'loc', + 'name', + 'nbytes', + 'ndim', + 'plot', + 'real', + 'shape', + 'size', + 'sizes', + 'str', + 'variable', +]) +def test_dataarray_property(prop): + arr = make_xrarray({'x': 10, 'y': 5}) + getattr(arr, prop) + + +@pytest.mark.parametrize("func,sparse_output", [ + (do('all'), False), + (do('any'), False), + (do('assign_attrs', {'foo': 'bar'}), True), + (do('assign_coords', x=make_xrarray({'x': 10}).x + 1), True), + (do('astype', int), True), + (do('broadcast_equals', make_xrarray({'x': 10, 'y': 5})), False), + (do('clip', min=0, max=1), True), + (do('compute'), True), + (do('conj'), True), + (do('copy'), True), + (do('count'), False), + (do('diff', 'x'), True), + (do('drop', 'x'), True), + (do('equals', make_xrarray({'x': 10, 'y': 5})), False), + (do('expand_dims', {'z': 2}, axis=2), True), + (do('get_axis_num', 'x'), False), + (do('get_index', 'x'), False), + (do('identical', make_xrarray({'x': 5, 'y': 5})), False), + (do('integrate', 'x'), True), + (do('isel', {'x': slice(0, 3), 'y': slice(2, 4)}), True), + (do('isnull'), True), + (do('load'), True), + (do('mean'), False), + (do('persist'), True), + (do('reindex', {'x': [1, 2, 3]}), True), + (do('rename', 'foo'), True), + (do('reorder_levels'), True), + (do('reset_coords', drop=True), True), + (do('reset_index', 'x'), True), + (do('round'), True), + (do('sel', x=[0, 1, 2]), True), + (do('shift'), True), + (do('sortby', 'x', ascending=False), True), + (do('stack', z={'x', 'y'}), True), + (do('transpose'), True), + + # TODO + # isel_points + # sel_points + # set_index + # swap_dims + + param(do('argmax'), True, + marks=xfail(reason='Missing implementation for np.argmax')), + param(do('argmin'), True, + marks=xfail(reason='Missing implementation for np.argmin')), + param(do('argsort'), True, + marks=xfail(reason="'COO' object has no attribute 'argsort'")), + param(do('bfill', dim='x'), False, + 
marks=xfail(reason='Missing implementation for np.flip')), + param(do('chunk', chunks=(5, 5)), False, + marks=xfail(reason='Coercion to dense')), + param(do('combine_first', make_xrarray({'x': 10, 'y': 5})), True, + marks=xfail(reason='mixed sparse-dense operation')), + param(do('conjugate'), False, + marks=xfail(reason="'COO' object has no attribute 'conjugate'")), + param(do('cumprod'), True, + marks=xfail(reason='Missing implementation for np.nancumprod')), + param(do('cumsum'), True, + marks=xfail(reason='Missing implementation for np.nancumsum')), + param(do('differentiate', 'x'), False, + marks=xfail(reason='Missing implementation for np.gradient')), + param(do('dot', make_xrarray({'x': 10, 'y': 5})), True, + marks=xfail(reason='Missing implementation for np.einsum')), + param(do('dropna', 'x'), False, + marks=xfail(reason='Coercion to dense')), + param(do('ffill', 'x'), False, + marks=xfail(reason='Coercion to dense via bottleneck.push')), + param(do('fillna', 0), True, + marks=xfail(reason='Missing implementation for np.result_type')), + param(do('interp', coords={'x': np.arange(10) + 0.5}), True, + marks=xfail(reason='Coercion to dense')), + param(do('interp_like', + make_xrarray({'x': 10, 'y': 5}, + coords={'x': np.arange(10) + 0.5, + 'y': np.arange(5) + 0.5})), True, + marks=xfail(reason='Indexing COO with more than one iterable index')), # noqa + param(do('interpolate_na', 'x'), True, + marks=xfail(reason='Coercion to dense')), + param(do('isin', [1, 2, 3]), False, + marks=xfail(reason='Missing implementation for np.isin')), + param(do('item', (1, 1)), False, + marks=xfail(reason="'COO' object has no attribute 'item'")), + param(do('max'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('median'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('min'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('notnull'), False, + marks=xfail(reason="'COO' object has no attribute 'notnull'")), + param(do('pipe', np.sum, axis=1), True, + marks=xfail(reason='Missing implementation for np.result_type')), + param(do('prod'), False, + marks=xfail(reason='Missing implementation for np.result_type')), + param(do('quantile', q=0.5), False, + marks=xfail(reason='Missing implementation for np.nanpercentile')), + param(do('rank', 'x'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('reduce', np.sum, dim='x'), False, + marks=xfail(reason='Coercion to dense')), + param(do('reindex_like', + make_xrarray({'x': 10, 'y': 5}, + coords={'x': np.arange(10) + 0.5, + 'y': np.arange(5) + 0.5})), + True, + marks=xfail(reason='Indexing COO with more than one iterable index')), # noqa + param(do('roll', x=2), True, + marks=xfail(reason='Missing implementation for np.result_type')), + param(do('sel', x=[0, 1, 2], y=[2, 3]), True, + marks=xfail(reason='Indexing COO with more than one iterable index')), # noqa + param(do('std'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('sum'), False, + marks=xfail(reason='Missing implementation for np.result_type')), + param(do('var'), False, + marks=xfail(reason='Coercion to dense via bottleneck')), + param(do('where', make_xrarray({'x': 10, 'y': 5}) > 0.5), False, + marks=xfail(reason='Conversion of dense to sparse when using sparse mask')), # noqa +], +ids=repr) +def test_dataarray_method(func, sparse_output): + arr_s = make_xrarray({'x': 10, 'y': 5}, + coords={'x': np.arange(10), 'y': np.arange(5)}) + arr_d = xr.DataArray( + 
arr_s.data.todense(), + coords=arr_s.coords, + dims=arr_s.dims) + ret_s = func(arr_s) + ret_d = func(arr_d) + + if sparse_output: + assert isinstance(ret_s.data, SparseArray) + assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) + else: + assert np.allclose(ret_s, ret_d, equal_nan=True) + + +@pytest.mark.parametrize("func,sparse_output", [ + (do('squeeze'), True), + param(do('searchsorted', [1, 2, 3]), False, + marks=xfail(reason="'COO' object has no attribute 'searchsorted'")), +]) +def test_datarray_1d_method(func, sparse_output): + arr_s = make_xrarray({'x': 10}, coords={'x': np.arange(10)}) + arr_d = xr.DataArray( + arr_s.data.todense(), + coords=arr_s.coords, + dims=arr_s.dims) + ret_s = func(arr_s) + ret_d = func(arr_d) + + if sparse_output: + assert isinstance(ret_s.data, SparseArray) + assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) + else: + assert np.allclose(ret_s, ret_d, equal_nan=True) + + +class TestSparseDataArrayAndDataset: + @pytest.fixture(autouse=True) + def setUp(self): + self.sp_ar = sparse.random((4, 6), random_state=0, density=0.5) + self.sp_xr = xr.DataArray(self.sp_ar, coords={'x': range(4)}, + dims=('x', 'y'), name='foo') + self.ds_ar = self.sp_ar.todense() + self.ds_xr = xr.DataArray(self.ds_ar, coords={'x': range(4)}, + dims=('x', 'y'), name='foo') + + @pytest.mark.xfail(reason='Missing implementation for np.result_type') + def test_to_dataset_roundtrip(self): + x = self.sp_xr + assert_equal(x, x.to_dataset('x').to_array('x')) + + def test_align(self): + a1 = xr.DataArray( + COO.from_numpy(np.arange(4)), + dims=['x'], + coords={'x': ['a', 'b', 'c', 'd']}) + b1 = xr.DataArray( + COO.from_numpy(np.arange(4)), + dims=['x'], + coords={'x': ['a', 'b', 'd', 'e']}) + a2, b2 = xr.align(a1, b1, join='inner') + assert isinstance(a2.data, sparse.SparseArray) + assert isinstance(b2.data, sparse.SparseArray) + assert np.all(a2.coords['x'].data == ['a', 'b', 'd']) + assert np.all(b2.coords['x'].data == ['a', 'b', 'd']) + + @pytest.mark.xfail( + reason="COO objects currently do not accept more than one " + "iterable index at a time") + def test_align_2d(self): + A1 = xr.DataArray(self.sp_ar, dims=['x', 'y'], coords={ + 'x': np.arange(self.sp_ar.shape[0]), + 'y': np.arange(self.sp_ar.shape[1]) + }) + + A2 = xr.DataArray(self.sp_ar, dims=['x', 'y'], coords={ + 'x': np.arange(1, self.sp_ar.shape[0] + 1), + 'y': np.arange(1, self.sp_ar.shape[1] + 1) + }) + + B1, B2 = xr.align(A1, A2, join='inner') + assert np.all(B1.coords['x'] == np.arange(1, self.sp_ar.shape[0])) + assert np.all(B1.coords['y'] == np.arange(1, self.sp_ar.shape[0])) + assert np.all(B1.coords['x'] == B2.coords['x']) + assert np.all(B1.coords['y'] == B2.coords['y']) + + @pytest.mark.xfail(reason="fill value leads to sparse-dense operation") + def test_align_outer(self): + a1 = xr.DataArray( + COO.from_numpy(np.arange(4)), + dims=['x'], + coords={'x': ['a', 'b', 'c', 'd']}) + b1 = xr.DataArray( + COO.from_numpy(np.arange(4)), + dims=['x'], + coords={'x': ['a', 'b', 'd', 'e']}) + a2, b2 = xr.align(a1, b1, join='outer') + assert isinstance(a2.data, sparse.SparseArray) + assert isinstance(b2.data, sparse.SparseArray) + assert np.all(a2.coords['x'].data == ['a', 'b', 'c', 'd']) + assert np.all(b2.coords['x'].data == ['a', 'b', 'c', 'd']) + + @pytest.mark.xfail(reason='Missing implementation for np.result_type') + def test_concat(self): + ds1 = xr.Dataset(data_vars={'d': self.sp_xr}) + ds2 = xr.Dataset(data_vars={'d': self.sp_xr}) + ds3 = xr.Dataset(data_vars={'d': self.sp_xr}) + out 
= xr.concat([ds1, ds2, ds3], dim='x') + assert_sparse_eq( + out['d'].data, + sparse.concatenate([self.sp_ar, self.sp_ar, self.sp_ar], axis=0) + ) + + out = xr.concat([self.sp_xr, self.sp_xr, self.sp_xr], dim='y') + assert_sparse_eq( + out.data, + sparse.concatenate([self.sp_ar, self.sp_ar, self.sp_ar], axis=1) + ) + + def test_stack(self): + arr = make_xrarray({'w': 2, 'x': 3, 'y': 4}) + stacked = arr.stack(z=('x', 'y')) + + z = pd.MultiIndex.from_product( + [np.arange(3), np.arange(4)], + names=['x', 'y']) + + expected = xr.DataArray( + arr.data.reshape((2, -1)), + {'w': [0, 1], 'z': z}, + dims=['w', 'z']) + + assert_equal(expected, stacked) + + roundtripped = stacked.unstack() + assert arr.identical(roundtripped) + + def test_ufuncs(self): + x = self.sp_xr + assert_equal(np.sin(x), xu.sin(x)) + + def test_dataarray_repr(self): + a = xr.DataArray( + COO.from_numpy(np.ones((4))), + dims=['x'], + coords={'y': ('x', COO.from_numpy(np.arange(4)))}) + expected = dedent("""\ + + + Coordinates: + y (x) int64 ... + Dimensions without coordinates: x""") + assert expected == repr(a) + + def test_dataset_repr(self): + ds = xr.Dataset( + data_vars={'a': ('x', COO.from_numpy(np.ones((4))))}, + coords={'y': ('x', COO.from_numpy(np.arange(4)))}) + expected = dedent("""\ + + Dimensions: (x: 4) + Coordinates: + y (x) int64 ... + Dimensions without coordinates: x + Data variables: + a (x) float64 ...""") + assert expected == repr(ds) + + def test_dataarray_pickle(self): + a1 = xr.DataArray( + COO.from_numpy(np.ones((4))), + dims=['x'], + coords={'y': ('x', COO.from_numpy(np.arange(4)))}) + a2 = pickle.loads(pickle.dumps(a1)) + assert_identical(a1, a2) + + def test_dataset_pickle(self): + ds1 = xr.Dataset( + data_vars={'a': ('x', COO.from_numpy(np.ones((4))))}, + coords={'y': ('x', COO.from_numpy(np.arange(4)))}) + ds2 = pickle.loads(pickle.dumps(ds1)) + assert_identical(ds1, ds2) + + def test_coarsen(self): + a1 = self.ds_xr + a2 = self.sp_xr + m1 = a1.coarsen(x=2, boundary='trim').mean() + m2 = a2.coarsen(x=2, boundary='trim').mean() + + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="No implementation of np.pad") + def test_rolling(self): + a1 = self.ds_xr + a2 = self.sp_xr + m1 = a1.rolling(x=2, center=True).mean() + m2 = a2.rolling(x=2, center=True).mean() + + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="Coercion to dense") + def test_rolling_exp(self): + a1 = self.ds_xr + a2 = self.sp_xr + m1 = a1.rolling_exp(x=2, center=True).mean() + m2 = a2.rolling_exp(x=2, center=True).mean() + + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="No implementation of np.einsum") + def test_dot(self): + a1 = self.xp_xr.dot(self.xp_xr[0]) + a2 = self.sp_ar.dot(self.sp_ar[0]) + assert_equal(a1, a2) + + @pytest.mark.xfail(reason="Groupby reductions produce dense output") + def test_groupby(self): + x1 = self.ds_xr + x2 = self.sp_xr + m1 = x1.groupby('x').mean(xr.ALL_DIMS) + m2 = x2.groupby('x').mean(xr.ALL_DIMS) + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="Groupby reductions produce dense output") + def test_groupby_first(self): + x = self.sp_xr.copy() + x.coords['ab'] = ('x', ['a', 'a', 'b', 'b']) + x.groupby('ab').first() + x.groupby('ab').first(skipna=False) + + @pytest.mark.xfail(reason="Groupby 
reductions produce dense output") + def test_groupby_bins(self): + x1 = self.ds_xr + x2 = self.sp_xr + m1 = x1.groupby_bins('x', bins=[0, 3, 7, 10]).sum() + m2 = x2.groupby_bins('x', bins=[0, 3, 7, 10]).sum() + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="Resample produces dense output") + def test_resample(self): + t1 = xr.DataArray(np.linspace(0, 11, num=12), + coords=[pd.date_range('15/12/1999', + periods=12, freq=pd.DateOffset(months=1))], + dims='time') + t2 = t1.copy() + t2.data = COO(t2.data) + m1 = t1.resample(time="QS-DEC").mean() + m2 = t2.resample(time="QS-DEC").mean() + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail + def test_reindex(self): + x1 = self.ds_xr + x2 = self.sp_xr + for kwargs in [{'x': [2, 3, 4]}, + {'x': [1, 100, 2, 101, 3]}, + {'x': [2.5, 3, 3.5], 'y': [2, 2.5, 3]}]: + m1 = x1.reindex(**kwargs) + m2 = x2.reindex(**kwargs) + assert np.allclose(m1, m2, equal_nan=True) + + @pytest.mark.xfail + def test_merge(self): + x = self.sp_xr + y = xr.merge([x, x.rename('bar')]).to_array() + assert isinstance(y, sparse.SparseArray) + + @pytest.mark.xfail + def test_where(self): + a = np.arange(10) + cond = a > 3 + xr.DataArray(a).where(cond) + + s = COO.from_numpy(a) + cond = s > 3 + xr.DataArray(s).where(cond) + + x = xr.DataArray(s) + cond = x > 3 + x.where(cond) + + +class TestSparseCoords: + @pytest.mark.xfail(reason="Coercion of coords to dense") + def test_sparse_coords(self): + xr.DataArray( + COO.from_numpy(np.arange(4)), + dims=['x'], + coords={'x': COO.from_numpy([1, 2, 3, 4])}) From d1935ffd24fd83e9b5a1347aed11773da03c878e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 6 Aug 2019 02:19:35 +0100 Subject: [PATCH 05/10] More annotations (#3177) * Annotations for Dataset.drop et al * Annotations for Dataset.interpolate et al * Dataset.drop(DataArray) * flake8 * trivial * @overload Dataset.drop * docstring tweaks * Clean up redundant code --- xarray/core/common.py | 9 +- xarray/core/dataarray.py | 35 ++++-- xarray/core/dataset.py | 193 +++++++++++++++++++++++---------- xarray/tests/test_dataarray.py | 4 +- xarray/tests/test_dataset.py | 6 + 5 files changed, 178 insertions(+), 69 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index bae3b6cd73d..93a5bb71b07 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -20,6 +20,7 @@ ALL_DIMS = ReprObject('') +C = TypeVar('C') T = TypeVar('T') @@ -297,9 +298,11 @@ def get_index(self, key: Hashable) -> pd.Index: # need to ensure dtype=int64 in case range is empty on Python 2 return pd.Index(range(self.sizes[key]), name=key, dtype=np.int64) - def _calc_assign_results(self, kwargs: Mapping[str, T] - ) -> MutableMapping[str, T]: - results = SortedKeysDict() # type: SortedKeysDict[str, T] + def _calc_assign_results( + self: C, + kwargs: Mapping[Hashable, Union[T, Callable[[C], T]]] + ) -> MutableMapping[Hashable, T]: + results = SortedKeysDict() # type: SortedKeysDict[Hashable, T] for k, v in kwargs.items(): if callable(v): results[k] = v(self) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ba6477f34cc..19d595079e5 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4,7 +4,8 @@ from collections import OrderedDict from numbers import Number from typing import (Any, Callable, Dict, Hashable, Iterable, List, Mapping, - Optional, Sequence, Tuple, Union, cast, TYPE_CHECKING) + Optional, Sequence, Tuple, 
Union, cast, overload, + TYPE_CHECKING) import numpy as np import pandas as pd @@ -1755,17 +1756,35 @@ def transpose(self, def T(self) -> 'DataArray': return self.transpose() - def drop(self, - labels: Union[Hashable, Sequence[Hashable]], - dim: Hashable = None, - *, - errors: str = 'raise') -> 'DataArray': + # Drop coords + @overload + def drop( + self, + labels: Union[Hashable, Iterable[Hashable]], + *, + errors: str = 'raise' + ) -> 'DataArray': + ... + + # Drop index labels along dimension + @overload # noqa: F811 + def drop( + self, + labels: Any, # array-like + dim: Hashable, + *, + errors: str = 'raise' + ) -> 'DataArray': + ... + + def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 """Drop coordinates or index labels from this DataArray. Parameters ---------- labels : hashable or sequence of hashables - Name(s) of coordinate variables or index labels to drop. + Name(s) of coordinates or index labels to drop. + If dim is not None, labels can be any array-like. dim : hashable, optional Dimension along which to drop index labels. By default (if ``dim is None``), drops coordinates rather than index labels. @@ -1778,8 +1797,6 @@ def drop(self, ------- dropped : DataArray """ - if utils.is_scalar(labels): - labels = [labels] ds = self._to_temp_dataset().drop(labels, dim, errors=errors) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b00dad965ed..5d3ca932ccc 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6,9 +6,10 @@ from distutils.version import LooseVersion from numbers import Number from pathlib import Path -from typing import (Any, DefaultDict, Dict, Hashable, Iterable, Iterator, List, - Mapping, MutableMapping, Optional, Sequence, Set, Tuple, - Union, cast, TYPE_CHECKING) +from typing import (Any, Callable, DefaultDict, Dict, Hashable, Iterable, + Iterator, List, Mapping, MutableMapping, Optional, + Sequence, Set, Tuple, Union, cast, overload, + TYPE_CHECKING) import numpy as np import pandas as pd @@ -315,10 +316,10 @@ class _LocIndexer: def __init__(self, dataset: 'Dataset'): self.dataset = dataset - def __getitem__(self, key: Mapping[str, Any]) -> 'Dataset': + def __getitem__(self, key: Mapping[Hashable, Any]) -> 'Dataset': if not utils.is_dict_like(key): raise TypeError('can only lookup dictionaries from Dataset.loc') - return self.dataset.sel(**key) + return self.dataset.sel(key) class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): @@ -792,7 +793,7 @@ def _replace_with_new_dims( # type: ignore self, variables: 'OrderedDict[Any, Variable]', coord_names: set = None, - attrs: 'OrderedDict' = __default, + attrs: Optional['OrderedDict'] = __default, indexes: 'OrderedDict[Any, pd.Index]' = __default, inplace: bool = False, ) -> 'Dataset': @@ -3261,7 +3262,8 @@ def merge( return self._replace_vars_and_dims(variables, coord_names, dims, inplace=inplace) - def _assert_all_in_dataset(self, names, virtual_okay=False): + def _assert_all_in_dataset(self, names: Iterable[Hashable], + virtual_okay: bool = False) -> None: bad_names = set(names) - set(self._variables) if virtual_okay: bad_names -= self.virtual_variables @@ -3269,14 +3271,36 @@ def _assert_all_in_dataset(self, names, virtual_okay=False): raise ValueError('One or more of the specified variables ' 'cannot be found in this dataset') - def drop(self, labels, dim=None, *, errors='raise'): + # Drop variables + @overload + def drop( + self, + labels: Union[Hashable, Iterable[Hashable]], + *, + errors: str = 'raise' + ) -> 'Dataset': + 
... + + # Drop index labels along dimension + @overload # noqa: F811 + def drop( + self, + labels: Any, # array-like + dim: Hashable, + *, + errors: str = 'raise' + ) -> 'Dataset': + ... + + def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 """Drop variables or index labels from this dataset. Parameters ---------- - labels : scalar or list of scalars + labels : hashable or iterable of hashables Name(s) of variables or index labels to drop. - dim : None or str, optional + If dim is not None, labels can be any array-like. + dim : None or hashable, optional Dimension along which to drop index labels. By default (if ``dim is None``), drops variables rather than index labels. errors: {'raise', 'ignore'}, optional @@ -3291,11 +3315,21 @@ def drop(self, labels, dim=None, *, errors='raise'): """ if errors not in ['raise', 'ignore']: raise ValueError('errors must be either "raise" or "ignore"') - if utils.is_scalar(labels): - labels = [labels] + if dim is None: + if isinstance(labels, str) or not isinstance(labels, Iterable): + labels = {labels} + else: + labels = set(labels) + return self._drop_vars(labels, errors=errors) else: + # Don't cast to set, as it would harm performance when labels + # is a large numpy array + if utils.is_scalar(labels): + labels = [labels] + labels = np.asarray(labels) + try: index = self.indexes[dim] except KeyError: @@ -3304,25 +3338,38 @@ def drop(self, labels, dim=None, *, errors='raise'): new_index = index.drop(labels, errors=errors) return self.loc[{dim: new_index}] - def _drop_vars(self, names, errors='raise'): + def _drop_vars( + self, + names: set, + errors: str = 'raise' + ) -> 'Dataset': if errors == 'raise': self._assert_all_in_dataset(names) - drop = set(names) + variables = OrderedDict((k, v) for k, v in self._variables.items() - if k not in drop) + if k not in names) coord_names = set(k for k in self._coord_names if k in variables) indexes = OrderedDict((k, v) for k, v in self.indexes.items() - if k not in drop) + if k not in names) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes) - def drop_dims(self, drop_dims, *, errors='raise'): + def drop_dims( + self, + drop_dims: Union[Hashable, Iterable[Hashable]], + *, + errors: str = 'raise' + ) -> 'Dataset': """Drop dimensions and associated variables from this dataset. Parameters ---------- drop_dims : str or list Dimension or dimensions to drop. + errors: {'raise', 'ignore'}, optional + If 'raise' (default), raises a ValueError error if any of the + dimensions passed are not in the dataset. If 'ignore', any given + labels that are in the dataset are dropped and no error is raised. Returns ------- @@ -3338,8 +3385,10 @@ def drop_dims(self, drop_dims, *, errors='raise'): if errors not in ['raise', 'ignore']: raise ValueError('errors must be either "raise" or "ignore"') - if utils.is_scalar(drop_dims): + if isinstance(drop_dims, str) or not isinstance(drop_dims, Iterable): drop_dims = [drop_dims] + else: + drop_dims = list(drop_dims) if errors == 'raise': missing_dimensions = [d for d in drop_dims if d not in self.dims] @@ -3351,7 +3400,7 @@ def drop_dims(self, drop_dims, *, errors='raise'): for d in v.dims if d in drop_dims) return self._drop_vars(drop_vars) - def transpose(self, *dims): + def transpose(self, *dims: Hashable) -> 'Dataset': """Return a new Dataset object with all array dimensions transposed. 
Although the order of dimensions on each array will change, the dataset @@ -3359,7 +3408,7 @@ def transpose(self, *dims): Parameters ---------- - *dims : str, optional + *dims : Hashable, optional By default, reverse the dimensions on each array. Otherwise, reorder the dimensions to this order. @@ -3391,13 +3440,19 @@ def transpose(self, *dims): ds._variables[name] = var.transpose(*var_dims) return ds - def dropna(self, dim, how='any', thresh=None, subset=None): + def dropna( + self, + dim: Hashable, + how: str = 'any', + thresh: int = None, + subset: Iterable[Hashable] = None + ): """Returns a new dataset with dropped labels for missing values along the provided dimension. Parameters ---------- - dim : str + dim : Hashable Dimension along which to drop missing values. Dropping along multiple dimensions simultaneously is not yet supported. how : {'any', 'all'}, optional @@ -3405,8 +3460,8 @@ def dropna(self, dim, how='any', thresh=None, subset=None): * all : if all values are NA, drop that label thresh : int, default None If supplied, require this many non-NA values. - subset : sequence, optional - Subset of variables to check for missing values. By default, all + subset : iterable of hashable, optional + Which variables to check for missing values. By default, all variables in the dataset are checked. Returns @@ -3421,7 +3476,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): raise ValueError('%s must be a single dataset dimension' % dim) if subset is None: - subset = list(self.data_vars) + subset = iter(self.data_vars) count = np.zeros(self.dims[dim], dtype=np.int64) size = 0 @@ -3430,7 +3485,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) + count += np.asarray(array.count(dims)) # type: ignore size += np.prod([self.dims[d] for d in dims]) if thresh is not None: @@ -3446,7 +3501,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): return self.isel({dim: mask}) - def fillna(self, value): + def fillna(self, value: Any) -> 'Dataset': """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -3475,14 +3530,19 @@ def fillna(self, value): out = ops.fillna(self, value) return out - def interpolate_na(self, dim=None, method='linear', limit=None, - use_coordinate=True, - **kwargs): + def interpolate_na( + self, + dim: Hashable = None, + method: str = 'linear', + limit: int = None, + use_coordinate: Union[bool, Hashable] = True, + **kwargs: Any + ) -> 'Dataset': """Interpolate values according to different methods. Parameters ---------- - dim : str + dim : Hashable Specifies the dimension along which to interpolate. method : {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'polynomial', 'barycentric', 'krog', 'pchip', @@ -3506,6 +3566,8 @@ def interpolate_na(self, dim=None, method='linear', limit=None, limit : int, default None Maximum number of consecutive NaNs to fill. Must be greater than 0 or None for no limit. 
+ kwargs : any + parameters passed verbatim to the underlying interplation function Returns ------- @@ -3524,14 +3586,14 @@ def interpolate_na(self, dim=None, method='linear', limit=None, **kwargs) return new - def ffill(self, dim, limit=None): - '''Fill NaN values by propogating values forward + def ffill(self, dim: Hashable, limit: int = None) -> 'Dataset': + """Fill NaN values by propogating values forward *Requires bottleneck.* Parameters ---------- - dim : str + dim : Hashable Specifies the dimension along which to propagate values when filling. limit : int, default None @@ -3543,14 +3605,14 @@ def ffill(self, dim, limit=None): Returns ------- Dataset - ''' + """ from .missing import ffill, _apply_over_vars_with_dim new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self, dim, limit=None): - '''Fill NaN values by propogating values backward + def bfill(self, dim: Hashable, limit: int = None) -> 'Dataset': + """Fill NaN values by propogating values backward *Requires bottleneck.* @@ -3568,13 +3630,13 @@ def bfill(self, dim, limit=None): Returns ------- Dataset - ''' + """ from .missing import bfill, _apply_over_vars_with_dim new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) return new - def combine_first(self, other): + def combine_first(self, other: 'Dataset') -> 'Dataset': """Combine two Datasets, default to data_vars of self. The new coordinates follow the normal broadcasting and alignment rules @@ -3583,7 +3645,7 @@ def combine_first(self, other): Parameters ---------- - other : DataArray + other : Dataset Used to fill all matching missing values in this array. Returns @@ -3593,13 +3655,21 @@ def combine_first(self, other): out = ops.fillna(self, other, join="outer", dataset_join="outer") return out - def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, - numeric_only=False, allow_lazy=False, **kwargs): + def reduce( + self, + func: Callable, + dim: Union[Hashable, Iterable[Hashable]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + numeric_only: bool = False, + allow_lazy: bool = False, + **kwargs: Any + ) -> 'Dataset': """Reduce this dataset by applying `func` along some dimension(s). Parameters ---------- - func : function + func : callable Function which can be called in the form `f(x, axis=axis, **kwargs)` to return the result of reducing an np.ndarray over an integer valued axis. @@ -3616,7 +3686,7 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, are removed. numeric_only : bool, optional If True, only apply ``func`` to variables with a numeric dtype. - **kwargs : dict + **kwargs : Any Additional keyword arguments passed on to ``func``. 
Returns @@ -3627,10 +3697,10 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, """ if dim is ALL_DIMS: dim = None - if isinstance(dim, str): - dims = set([dim]) - elif dim is None: + if dim is None: dims = set(self.dims) + elif isinstance(dim, str) or not isinstance(dim, Iterable): + dims = {dim} else: dims = set(dim) @@ -3642,9 +3712,12 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) - variables = OrderedDict() + variables = OrderedDict() # type: OrderedDict[Hashable, Variable] for name, var in self._variables.items(): - reduce_dims = [d for d in var.dims if d in dims] + reduce_dims = [ + d for d in var.dims + if d in dims + ] if name in self.coords: if not reduce_dims: variables[name] = var @@ -3660,7 +3733,7 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because # the former is often more efficient - reduce_dims = None + reduce_dims = None # type: ignore variables[name] = var.reduce(func, dim=reduce_dims, keep_attrs=keep_attrs, keepdims=keepdims, @@ -3674,12 +3747,18 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, return self._replace_with_new_dims( variables, coord_names=coord_names, attrs=attrs, indexes=indexes) - def apply(self, func, keep_attrs=None, args=(), **kwargs): + def apply( + self, + func: Callable, + keep_attrs: bool = None, + args: Iterable[Any] = (), + **kwargs: Any + ) -> 'Dataset': """Apply a function over the data variables in this dataset. Parameters ---------- - func : function + func : callable Function which can be called in the form `func(x, *args, **kwargs)` to transform each DataArray `x` in this dataset into another DataArray. @@ -3689,7 +3768,7 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs): be returned without attributes. args : tuple, optional Positional arguments passed on to `func`. - **kwargs : dict + **kwargs : Any Keyword arguments passed on to `func`. Returns @@ -3724,7 +3803,11 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs): attrs = self.attrs if keep_attrs else None return type(self)(variables, attrs=attrs) - def assign(self, variables=None, **variables_kwargs): + def assign( + self, + variables: Mapping[Hashable, Any] = None, + **variables_kwargs: Hashable + ) -> 'Dataset': """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. @@ -3737,7 +3820,7 @@ def assign(self, variables=None, **variables_kwargs): scalar, or array), they are simply assigned. **variables_kwargs: The keyword arguments form of ``variables``. - One of variables or variables_kwarg must be provided. + One of variables or variables_kwargs must be provided. 
Returns ------- diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5697704bdbc..000469f24bf 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1904,9 +1904,9 @@ def test_drop_coordinates(self): assert_identical(actual, expected) with raises_regex(ValueError, 'cannot be found'): - arr.drop(None) + arr.drop('w') - actual = expected.drop(None, errors='ignore') + actual = expected.drop('w', errors='ignore') assert_identical(actual, expected) renamed = arr.rename('foo') diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fc6f7f36938..fc15393f269 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2000,6 +2000,12 @@ def test_drop_index_labels(self): expected = data.isel(x=slice(0, 0)) assert_identical(expected, actual) + # DataArrays as labels are a nasty corner case as they are not + # Iterable[Hashable] - DataArray.__iter__ yields scalar DataArrays. + actual = data.drop(DataArray(['a', 'b', 'c']), 'x', errors='ignore') + expected = data.isel(x=slice(0, 0)) + assert_identical(expected, actual) + with raises_regex( ValueError, 'does not have coordinate labels'): data.drop(1, 'y') From c34685a34dcd74b18f29ba24fd469ff310016400 Mon Sep 17 00:00:00 2001 From: Andrew Barna Date: Mon, 5 Aug 2019 18:20:15 -0700 Subject: [PATCH 06/10] bump rasterio to 1.0.24 in doc building environment (#3186) --- doc/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/environment.yml b/doc/environment.yml index b2f89bd9f96..85161bf317f 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -12,7 +12,7 @@ dependencies: - ipython=7.2.0 - netCDF4=1.4.2 - cartopy=0.17.0 - - rasterio=1.0.13 + - rasterio=1.0.24 - zarr=2.2.0 - iris=2.2.0 - flake8=3.6.0 From 55593a8bcaf2edb79034507990eac9c55b41a07d Mon Sep 17 00:00:00 2001 From: Riley Brady Date: Tue, 6 Aug 2019 12:02:20 -0400 Subject: [PATCH 07/10] add climpred to related-projects (#3188) --- doc/related-projects.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/related-projects.rst b/doc/related-projects.rst index 3e6a745daa0..58b9a7c22c9 100644 --- a/doc/related-projects.rst +++ b/doc/related-projects.rst @@ -11,6 +11,7 @@ Geosciences ~~~~~~~~~~~ - `aospy `_: Automated analysis and management of gridded climate data. +- `climpred `_: Analysis of ensemble forecast models for climate prediction. - `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meterology data - `marc_analysis `_: Analysis package for CESM/MARC experiments and output. - `MetPy `_: A collection of tools in Python for reading, visualizing, and performing calculations with weather data. 
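
For reference, a minimal runnable sketch (not part of any patch in this series) of the Dataset methods whose signatures gain type annotations in the hunks above; runtime behaviour is unchanged and the dataset here is purely illustrative.

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'a': ('x', [1.0, np.nan, 3.0])}, coords={'x': [0, 1, 2]})
    ds.dropna('x')              # drop the label along 'x' where 'a' is NaN
    ds.fillna(0.0)              # replace missing values with a scalar
    ds.reduce(np.sum, dim='x')  # apply a numpy reduction over 'x'
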
From 1ab7569561db50eaccbae977b0ef69993e0c0d0c Mon Sep 17 00:00:00 2001 From: Andrew Barna Date: Tue, 6 Aug 2019 13:41:37 -0700 Subject: [PATCH 08/10] reduce the size of example dataset in dask docs (#3187) --- doc/dask.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/dask.rst b/doc/dask.rst index ba75eea74cc..b0ffd0c449d 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -58,9 +58,9 @@ argument to :py:func:`~xarray.open_dataset` or using the np.set_printoptions(precision=3, linewidth=100, threshold=100, edgeitems=3) ds = xr.Dataset({'temperature': (('time', 'latitude', 'longitude'), - np.random.randn(365, 180, 360)), - 'time': pd.date_range('2015-01-01', periods=365), - 'longitude': np.arange(360), + np.random.randn(30, 180, 180)), + 'time': pd.date_range('2015-01-01', periods=30), + 'longitude': np.arange(180), 'latitude': np.arange(89.5, -90.5, -1)}) ds.to_netcdf('example-data.nc') From 04597a8dfe134f57b5c4698ec72e4f368200f187 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 7 Aug 2019 12:17:06 +0000 Subject: [PATCH 09/10] mfdataset, concat now support the 'join' kwarg. (#3102) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mfdatset, concat now support the 'join' kwarg. Closes #1354 * Add whats-new.rst * Add concat tests * doc improvements. * update todo. * mfdataset tests. * manual_combine → combine_nested * Add tests for combine_nested & combine_coords * Update docstring. * lint. --- doc/whats-new.rst | 2 + xarray/backends/api.py | 20 +++++++-- xarray/core/combine.py | 76 ++++++++++++++++++++++++++--------- xarray/core/concat.py | 27 +++++++++---- xarray/core/merge.py | 9 ++++- xarray/tests/test_backends.py | 35 +++++++++++----- xarray/tests/test_combine.py | 56 ++++++++++++++++++++++---- xarray/tests/test_concat.py | 52 ++++++++++++++++++++++++ 8 files changed, 225 insertions(+), 52 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7de6794d5b4..40c1bbbcaf6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,8 @@ New functions/methods Enhancements ~~~~~~~~~~~~ +- :py:func:`~xarray.concat` and :py:func:`~xarray.open_mfdataset` now support the ``join`` kwarg. + It is passed down to :py:func:`~xarray.align`. By `Deepak Cherian `_. - In :py:meth:`~xarray.Dataset.to_zarr`, passing ``mode`` is not mandatory if ``append_dim`` is set, as it will automatically be set to ``'a'`` internally. By `David Brochart `_. diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 21d91a886af..2535c7118a5 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -609,7 +609,7 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', compat='no_conflicts', preprocess=None, engine=None, lock=None, data_vars='all', coords='different', combine='_old_auto', autoclose=None, parallel=False, - **kwargs): + join='outer', **kwargs): """Open multiple files as a single dataset. If combine='by_coords' then the function ``combine_by_coords`` is used to @@ -704,6 +704,16 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', parallel : bool, optional If True, the open and preprocess steps of this function will be performed in parallel using ``dask.delayed``. Default is False. 
+ join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + String indicating how to combine differing indexes + (excluding concat_dim) in objects + + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + - 'exact': instead of aligning, raise `ValueError` when indexes to be + aligned are not equal **kwargs : optional Additional arguments passed on to :py:func:`xarray.open_dataset`. @@ -798,18 +808,20 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', combined = auto_combine(datasets, concat_dim=concat_dim, compat=compat, data_vars=data_vars, - coords=coords, from_openmfds=True) + coords=coords, join=join, + from_openmfds=True) elif combine == 'nested': # Combined nested list by successive concat and merge operations # along each dimension, using structure given by "ids" combined = _nested_combine(datasets, concat_dims=concat_dim, compat=compat, data_vars=data_vars, - coords=coords, ids=ids) + coords=coords, ids=ids, join=join) elif combine == 'by_coords': # Redo ordering from coordinates, ignoring how they were ordered # previously combined = combine_by_coords(datasets, compat=compat, - data_vars=data_vars, coords=coords) + data_vars=data_vars, coords=coords, + join=join) else: raise ValueError("{} is an invalid option for the keyword argument" " ``combine``".format(combine)) diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 5718698f852..71da4e4e094 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -136,7 +136,7 @@ def _check_shape_tile_ids(combined_tile_ids): def _combine_nd(combined_ids, concat_dims, data_vars='all', coords='different', compat='no_conflicts', - fill_value=dtypes.NA): + fill_value=dtypes.NA, join='outer'): """ Combines an N-dimensional structure of datasets into one by applying a series of either concat and merge operations along each dimension. 
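
# A minimal sketch of the ``join`` keyword documented for open_mfdataset
# above; the file names are illustrative only and a netCDF backend
# (netCDF4 or scipy) is assumed to be installed.
import xarray as xr

xr.Dataset({'a': (('t', 'y'), [[1.0]])},
           coords={'t': [0], 'y': [0.0]}).to_netcdf('part0.nc')
xr.Dataset({'a': (('t', 'y'), [[2.0]])},
           coords={'t': [1], 'y': [0.1]}).to_netcdf('part1.nc')

# join='outer' (the default) unions the non-concatenated 'y' index and
# pads with NaN; join='exact' would instead raise ValueError because the
# 'y' values differ between the two files.
ds = xr.open_mfdataset(['part0.nc', 'part1.nc'], combine='nested',
                       concat_dim='t', join='outer')
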
@@ -177,13 +177,14 @@ def _combine_nd(combined_ids, concat_dims, data_vars='all', data_vars=data_vars, coords=coords, compat=compat, - fill_value=fill_value) + fill_value=fill_value, + join=join) (combined_ds,) = combined_ids.values() return combined_ds def _combine_all_along_first_dim(combined_ids, dim, data_vars, coords, compat, - fill_value=dtypes.NA): + fill_value=dtypes.NA, join='outer'): # Group into lines of datasets which must be combined along dim # need to sort by _new_tile_id first for groupby to work @@ -197,12 +198,13 @@ def _combine_all_along_first_dim(combined_ids, dim, data_vars, coords, compat, combined_ids = OrderedDict(sorted(group)) datasets = combined_ids.values() new_combined_ids[new_id] = _combine_1d(datasets, dim, compat, - data_vars, coords, fill_value) + data_vars, coords, fill_value, + join) return new_combined_ids def _combine_1d(datasets, concat_dim, compat='no_conflicts', data_vars='all', - coords='different', fill_value=dtypes.NA): + coords='different', fill_value=dtypes.NA, join='outer'): """ Applies either concat or merge to 1D list of datasets depending on value of concat_dim @@ -211,7 +213,7 @@ def _combine_1d(datasets, concat_dim, compat='no_conflicts', data_vars='all', if concat_dim is not None: try: combined = concat(datasets, dim=concat_dim, data_vars=data_vars, - coords=coords, fill_value=fill_value) + coords=coords, fill_value=fill_value, join=join) except ValueError as err: if "encountered unexpected variable" in str(err): raise ValueError("These objects cannot be combined using only " @@ -222,7 +224,8 @@ def _combine_1d(datasets, concat_dim, compat='no_conflicts', data_vars='all', else: raise else: - combined = merge(datasets, compat=compat, fill_value=fill_value) + combined = merge(datasets, compat=compat, fill_value=fill_value, + join=join) return combined @@ -233,7 +236,7 @@ def _new_tile_id(single_id_ds_pair): def _nested_combine(datasets, concat_dims, compat, data_vars, coords, ids, - fill_value=dtypes.NA): + fill_value=dtypes.NA, join='outer'): if len(datasets) == 0: return Dataset() @@ -254,12 +257,13 @@ def _nested_combine(datasets, concat_dims, compat, data_vars, coords, ids, # Apply series of concatenate or merge operations along each dimension combined = _combine_nd(combined_ids, concat_dims, compat=compat, data_vars=data_vars, coords=coords, - fill_value=fill_value) + fill_value=fill_value, join=join) return combined def combine_nested(datasets, concat_dim, compat='no_conflicts', - data_vars='all', coords='different', fill_value=dtypes.NA): + data_vars='all', coords='different', fill_value=dtypes.NA, + join='outer'): """ Explicitly combine an N-dimensional grid of datasets into one by using a succession of concat and merge operations along each dimension of the grid. 
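
# For reference, a small runnable sketch of the align() behaviour that the
# new ``join`` argument ultimately selects, shown here on in-memory objects
# and independent of the combine machinery in this file:
import xarray as xr

a = xr.Dataset({'v': ('y', [1.0])}, coords={'y': [0]})
b = xr.Dataset({'v': ('y', [2.0])}, coords={'y': [1]})
xr.align(a, b, join='outer')  # both results get y == [0, 1], NaN-padded
xr.align(a, b, join='inner')  # both results get y == [] (no common labels)
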
@@ -312,6 +316,16 @@ def combine_nested(datasets, concat_dim, compat='no_conflicts', Details are in the documentation of concat fill_value : scalar, optional Value to use for newly missing values + join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + String indicating how to combine differing indexes + (excluding concat_dim) in objects + + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + - 'exact': instead of aligning, raise `ValueError` when indexes to be + aligned are not equal Returns ------- @@ -383,7 +397,7 @@ def combine_nested(datasets, concat_dim, compat='no_conflicts', # The IDs argument tells _manual_combine that datasets aren't yet sorted return _nested_combine(datasets, concat_dims=concat_dim, compat=compat, data_vars=data_vars, coords=coords, ids=False, - fill_value=fill_value) + fill_value=fill_value, join=join) def vars_as_keys(ds): @@ -391,7 +405,7 @@ def vars_as_keys(ds): def combine_by_coords(datasets, compat='no_conflicts', data_vars='all', - coords='different', fill_value=dtypes.NA): + coords='different', fill_value=dtypes.NA, join='outer'): """ Attempt to auto-magically combine the given datasets into one by using dimension coordinates. @@ -439,6 +453,16 @@ def combine_by_coords(datasets, compat='no_conflicts', data_vars='all', Details are in the documentation of concat fill_value : scalar, optional Value to use for newly missing values + join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + String indicating how to combine differing indexes + (excluding concat_dim) in objects + + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + - 'exact': instead of aligning, raise `ValueError` when indexes to be + aligned are not equal Returns ------- @@ -498,7 +522,7 @@ def combine_by_coords(datasets, compat='no_conflicts', data_vars='all', # Concatenate along all of concat_dims one by one to create single ds concatenated = _combine_nd(combined_ids, concat_dims=concat_dims, data_vars=data_vars, coords=coords, - fill_value=fill_value) + fill_value=fill_value, join=join) # Check the overall coordinates are monotonically increasing for dim in concat_dims: @@ -511,7 +535,7 @@ def combine_by_coords(datasets, compat='no_conflicts', data_vars='all', concatenated_grouped_by_data_vars.append(concatenated) return merge(concatenated_grouped_by_data_vars, compat=compat, - fill_value=fill_value) + fill_value=fill_value, join=join) # Everything beyond here is only needed until the deprecation cycle in #2616 @@ -523,7 +547,7 @@ def combine_by_coords(datasets, compat='no_conflicts', data_vars='all', def auto_combine(datasets, concat_dim='_not_supplied', compat='no_conflicts', data_vars='all', coords='different', fill_value=dtypes.NA, - from_openmfds=False): + join='outer', from_openmfds=False): """ Attempt to auto-magically combine the given datasets into one. 
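
# A short runnable sketch of combine_nested with the ``join`` options
# documented above, mirroring the tests added later in this patch:
import xarray as xr

objs = [xr.Dataset({'x': [0], 'y': [0]}),
        xr.Dataset({'x': [1], 'y': [1]})]
xr.combine_nested(objs, concat_dim='x', join='outer')   # y -> [0, 1]
xr.combine_nested(objs, concat_dim='x', join='inner')   # y -> []
# join='exact' raises ValueError because the 'y' indexes are not equal.
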
@@ -571,6 +595,16 @@ def auto_combine(datasets, concat_dim='_not_supplied', compat='no_conflicts', Details are in the documentation of concat fill_value : scalar, optional Value to use for newly missing values + join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + String indicating how to combine differing indexes + (excluding concat_dim) in objects + + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + - 'exact': instead of aligning, raise `ValueError` when indexes to be + aligned are not equal Returns ------- @@ -629,7 +663,8 @@ def auto_combine(datasets, concat_dim='_not_supplied', compat='no_conflicts', return _old_auto_combine(datasets, concat_dim=concat_dim, compat=compat, data_vars=data_vars, - coords=coords, fill_value=fill_value) + coords=coords, fill_value=fill_value, + join=join) def _dimension_coords_exist(datasets): @@ -670,7 +705,7 @@ def _requires_concat_and_merge(datasets): def _old_auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT, compat='no_conflicts', data_vars='all', coords='different', - fill_value=dtypes.NA): + fill_value=dtypes.NA, join='outer'): if concat_dim is not None: dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim @@ -679,16 +714,17 @@ def _old_auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT, concatenated = [_auto_concat(list(datasets), dim=dim, data_vars=data_vars, coords=coords, - fill_value=fill_value) + fill_value=fill_value, join=join) for vars, datasets in grouped] else: concatenated = datasets - merged = merge(concatenated, compat=compat, fill_value=fill_value) + merged = merge(concatenated, compat=compat, fill_value=fill_value, + join=join) return merged def _auto_concat(datasets, dim=None, data_vars='all', coords='different', - fill_value=dtypes.NA): + fill_value=dtypes.NA, join='outer'): if len(datasets) == 1 and dim is None: # There is nothing more to combine, so kick out early. return datasets[0] diff --git a/xarray/core/concat.py b/xarray/core/concat.py index cd59d87870e..a6570525cc5 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -11,7 +11,7 @@ def concat(objs, dim=None, data_vars='all', coords='different', compat='equals', positions=None, indexers=None, mode=None, - concat_over=None, fill_value=dtypes.NA): + concat_over=None, fill_value=dtypes.NA, join='outer'): """Concatenate xarray objects along a new or existing dimension. Parameters @@ -52,7 +52,7 @@ def concat(objs, dim=None, data_vars='all', coords='different', * 'all': All coordinate variables will be concatenated, except those corresponding to other dimensions. * list of str: The listed coordinate variables will be concatenated, - in addition the 'minimal' coordinates. + in addition to the 'minimal' coordinates. compat : {'equals', 'identical'}, optional String indicating how to compare non-concatenated variables and dataset global attributes for potential conflicts. 'equals' means @@ -65,6 +65,17 @@ def concat(objs, dim=None, data_vars='all', coords='different', supplied, objects are concatenated in the provided order. 
fill_value : scalar, optional Value to use for newly missing values + join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + String indicating how to combine differing indexes + (excluding dim) in objects + + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + - 'exact': instead of aligning, raise `ValueError` when indexes to be + aligned are not equal + indexers, mode, concat_over : deprecated Returns @@ -76,7 +87,7 @@ def concat(objs, dim=None, data_vars='all', coords='different', merge auto_combine """ - # TODO: add join and ignore_index arguments copied from pandas.concat + # TODO: add ignore_index arguments copied from pandas.concat # TODO: support concatenating scalar coordinates even if the concatenated # dimension already exists from .dataset import Dataset @@ -116,7 +127,7 @@ def concat(objs, dim=None, data_vars='all', coords='different', else: raise TypeError('can only concatenate xarray Dataset and DataArray ' 'objects, got %s' % type(first_obj)) - return f(objs, dim, data_vars, coords, compat, positions, fill_value) + return f(objs, dim, data_vars, coords, compat, positions, fill_value, join) def _calc_concat_dim_coord(dim): @@ -212,7 +223,7 @@ def process_subset_opt(opt, subset): def _dataset_concat(datasets, dim, data_vars, coords, compat, positions, - fill_value=dtypes.NA): + fill_value=dtypes.NA, join='outer'): """ Concatenate a sequence of datasets along a new or existing dimension """ @@ -225,7 +236,7 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions, dim, coord = _calc_concat_dim_coord(dim) # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] - datasets = align(*datasets, join='outer', copy=False, exclude=[dim], + datasets = align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value) concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords) @@ -318,7 +329,7 @@ def ensure_common_dims(vars): def _dataarray_concat(arrays, dim, data_vars, coords, compat, - positions, fill_value=dtypes.NA): + positions, fill_value=dtypes.NA, join='outer'): arrays = list(arrays) if data_vars != 'all': @@ -337,5 +348,5 @@ def _dataarray_concat(arrays, dim, data_vars, coords, compat, datasets.append(arr._to_temp_dataset()) ds = _dataset_concat(datasets, dim, data_vars, coords, compat, - positions, fill_value=fill_value) + positions, fill_value=fill_value, join=join) return arrays[0]._from_temp_dataset(ds, name) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 9c909aa197c..289b70ed518 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -530,7 +530,14 @@ def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA): must be equal. The returned dataset then contains the combination of all non-null values. join : {'outer', 'inner', 'left', 'right', 'exact'}, optional - How to combine objects with different indexes. + String indicating how to combine differing indexes in objects. 
+ + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + - 'exact': instead of aligning, raise `ValueError` when indexes to be + aligned are not equal fill_value : scalar, optional Value to use for newly missing values diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 026ae6a55ff..bef9da9fb7f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2392,8 +2392,12 @@ class TestOpenMFDatasetWithDataVarsAndCoordsKw: var_name = 'v1' @contextlib.contextmanager - def setup_files_and_datasets(self): + def setup_files_and_datasets(self, fuzz=0): ds1, ds2 = self.gen_datasets_with_common_coord_and_time() + + # to test join='exact' + ds1['x'] = ds1.x + fuzz + with create_tmp_file() as tmpfile1: with create_tmp_file() as tmpfile2: @@ -2430,20 +2434,29 @@ def gen_datasets_with_common_coord_and_time(self): return ds1, ds2 + @pytest.mark.parametrize('combine', ['nested', 'by_coords']) @pytest.mark.parametrize('opt', ['all', 'minimal', 'different']) - def test_open_mfdataset_does_same_as_concat(self, opt): + @pytest.mark.parametrize('join', ['outer', 'inner', 'left', 'right']) + def test_open_mfdataset_does_same_as_concat(self, combine, opt, join): with self.setup_files_and_datasets() as (files, [ds1, ds2]): - with open_mfdataset(files, data_vars=opt, - combine='nested', concat_dim='t') as ds: - kwargs = dict(data_vars=opt, dim='t') - ds_expect = xr.concat([ds1, ds2], **kwargs) - assert_identical(ds, ds_expect) - with open_mfdataset(files, coords=opt, - combine='nested', concat_dim='t') as ds: - kwargs = dict(coords=opt, dim='t') - ds_expect = xr.concat([ds1, ds2], **kwargs) + if combine == 'by_coords': + files.reverse() + with open_mfdataset(files, data_vars=opt, combine=combine, + concat_dim='t', join=join) as ds: + ds_expect = xr.concat([ds1, ds2], data_vars=opt, dim='t', + join=join) assert_identical(ds, ds_expect) + @pytest.mark.parametrize('combine', ['nested', 'by_coords']) + @pytest.mark.parametrize('opt', ['all', 'minimal', 'different']) + def test_open_mfdataset_exact_join_raises_error(self, combine, opt): + with self.setup_files_and_datasets(fuzz=0.1) as (files, [ds1, ds2]): + if combine == 'by_coords': + files.reverse() + with raises_regex(ValueError, 'indexes along dimension'): + open_mfdataset(files, data_vars=opt, combine=combine, + concat_dim='t', join='exact') + def test_common_coord_when_datavars_all(self): opt = 'all' diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 2a71a3a3ed4..8c9308466a4 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -306,8 +306,8 @@ def test_check_lengths(self): _check_shape_tile_ids(combined_tile_ids) -class TestManualCombine: - def test_manual_concat(self): +class TestNestedCombine: + def test_nested_concat(self): objs = [Dataset({'x': [0]}), Dataset({'x': [1]})] expected = Dataset({'x': [0, 1]}) actual = combine_nested(objs, concat_dim='x') @@ -326,7 +326,7 @@ def test_manual_concat(self): expected = Dataset({'x': [0, 1, 2]}) assert_identical(expected, actual) - # ensure manual_combine handles non-sorted variables + # ensure combine_nested handles non-sorted variables objs = [Dataset(OrderedDict([('x', ('a', [0])), ('y', ('a', [0]))])), Dataset(OrderedDict([('y', ('a', [1])), ('x', ('a', [1]))]))] actual = combine_nested(objs, concat_dim='a') @@ -337,17 +337,37 @@ def 
test_manual_concat(self): with pytest.raises(KeyError): combine_nested(objs, concat_dim='x') + @pytest.mark.parametrize( + "join, expected", + [ + ('outer', Dataset({'x': [0, 1], 'y': [0, 1]})), + ('inner', Dataset({'x': [0, 1], 'y': []})), + ('left', Dataset({'x': [0, 1], 'y': [0]})), + ('right', Dataset({'x': [0, 1], 'y': [1]})), + ]) + def test_combine_nested_join(self, join, expected): + objs = [Dataset({'x': [0], 'y': [0]}), + Dataset({'x': [1], 'y': [1]})] + actual = combine_nested(objs, concat_dim='x', join=join) + assert_identical(expected, actual) + + def test_combine_nested_join_exact(self): + objs = [Dataset({'x': [0], 'y': [0]}), + Dataset({'x': [1], 'y': [1]})] + with raises_regex(ValueError, 'indexes along dimension'): + combine_nested(objs, concat_dim='x', join='exact') + def test_empty_input(self): assert_identical(Dataset(), combine_nested([], concat_dim='x')) # Fails because of concat's weird treatment of dimension coords, see #2975 @pytest.mark.xfail - def test_manual_concat_too_many_dims_at_once(self): + def test_nested_concat_too_many_dims_at_once(self): objs = [Dataset({'x': [0], 'y': [1]}), Dataset({'y': [0], 'x': [1]})] with pytest.raises(ValueError, match="not equal across datasets"): combine_nested(objs, concat_dim='x', coords='minimal') - def test_manual_concat_along_new_dim(self): + def test_nested_concat_along_new_dim(self): objs = [Dataset({'a': ('x', [10]), 'x': [0]}), Dataset({'a': ('x', [20]), 'x': [0]})] expected = Dataset({'a': (('t', 'x'), [[10], [20]]), 'x': [0]}) @@ -361,7 +381,7 @@ def test_manual_concat_along_new_dim(self): actual = combine_nested(objs, concat_dim=dim) assert_identical(expected, actual) - def test_manual_merge(self): + def test_nested_merge(self): data = Dataset({'x': 0}) actual = combine_nested([data, data, data], concat_dim=None) assert_identical(data, actual) @@ -450,7 +470,7 @@ def test_auto_combine_2d(self): result = combine_nested(datasets, concat_dim=['dim1', 'dim2']) assert_equal(result, expected) - def test_manual_combine_missing_data_new_dim(self): + def test_combine_nested_missing_data_new_dim(self): # Your data includes "time" and "station" dimensions, and each year's # data has a different set of stations. 
datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}), @@ -513,7 +533,7 @@ def test_combine_concat_over_redundant_nesting(self): expected = Dataset({'x': [0]}) assert_identical(expected, actual) - def test_manual_combine_but_need_auto_combine(self): + def test_combine_nested_but_need_auto_combine(self): objs = [Dataset({'x': [0, 1]}), Dataset({'x': [2], 'wall': [0]})] with raises_regex(ValueError, 'cannot be combined'): combine_nested(objs, concat_dim='x') @@ -574,6 +594,26 @@ def test_combine_by_coords(self): def test_empty_input(self): assert_identical(Dataset(), combine_by_coords([])) + @pytest.mark.parametrize( + "join, expected", + [ + ('outer', Dataset({'x': [0, 1], 'y': [0, 1]})), + ('inner', Dataset({'x': [0, 1], 'y': []})), + ('left', Dataset({'x': [0, 1], 'y': [0]})), + ('right', Dataset({'x': [0, 1], 'y': [1]})), + ]) + def test_combine_coords_join(self, join, expected): + objs = [Dataset({'x': [0], 'y': [0]}), + Dataset({'x': [1], 'y': [1]})] + actual = combine_nested(objs, concat_dim='x', join=join) + assert_identical(expected, actual) + + def test_combine_coords_join_exact(self): + objs = [Dataset({'x': [0], 'y': [0]}), + Dataset({'x': [1], 'y': [1]})] + with raises_regex(ValueError, 'indexes along dimension'): + combine_nested(objs, concat_dim='x', join='exact') + def test_infer_order_from_coords(self): data = create_test_data() objs = [data.isel(dim2=slice(4, 9)), data.isel(dim2=slice(4))] diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 6218f752bb7..deed6748761 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -164,6 +164,32 @@ def test_concat_errors(self): with raises_regex(ValueError, 'no longer a valid'): concat([data, data], 'new_dim', concat_over='different') + def test_concat_join_kwarg(self): + ds1 = Dataset({'a': (('x', 'y'), [[0]])}, + coords={'x': [0], 'y': [0]}) + ds2 = Dataset({'a': (('x', 'y'), [[0]])}, + coords={'x': [1], 'y': [0.0001]}) + + expected = dict() + expected['outer'] = Dataset({'a': (('x', 'y'), + [[0, np.nan], [np.nan, 0]])}, + {'x': [0, 1], 'y': [0, 0.0001]}) + expected['inner'] = Dataset({'a': (('x', 'y'), [[], []])}, + {'x': [0, 1], 'y': []}) + expected['left'] = Dataset({'a': (('x', 'y'), + np.array([0, np.nan], ndmin=2).T)}, + coords={'x': [0, 1], 'y': [0]}) + expected['right'] = Dataset({'a': (('x', 'y'), + np.array([np.nan, 0], ndmin=2).T)}, + coords={'x': [0, 1], 'y': [0.0001]}) + + with raises_regex(ValueError, "indexes along dimension 'y'"): + actual = concat([ds1, ds2], join='exact', dim='x') + + for join in expected: + actual = concat([ds1, ds2], join=join, dim='x') + assert_equal(actual, expected[join]) + def test_concat_promote_shape(self): # mixed dims within variables objs = [Dataset({}, {'x': 0}), Dataset({'x': [1]})] @@ -318,3 +344,29 @@ def test_concat_fill_value(self, fill_value): dims=['y', 'x'], coords={'x': [1, 2, 3]}) actual = concat((foo, bar), dim='y', fill_value=fill_value) assert_identical(actual, expected) + + def test_concat_join_kwarg(self): + ds1 = Dataset({'a': (('x', 'y'), [[0]])}, + coords={'x': [0], 'y': [0]}).to_array() + ds2 = Dataset({'a': (('x', 'y'), [[0]])}, + coords={'x': [1], 'y': [0.0001]}).to_array() + + expected = dict() + expected['outer'] = Dataset({'a': (('x', 'y'), + [[0, np.nan], [np.nan, 0]])}, + {'x': [0, 1], 'y': [0, 0.0001]}) + expected['inner'] = Dataset({'a': (('x', 'y'), [[], []])}, + {'x': [0, 1], 'y': []}) + expected['left'] = Dataset({'a': (('x', 'y'), + np.array([0, np.nan], ndmin=2).T)}, + coords={'x': [0, 1], 'y': [0]}) + 
expected['right'] = Dataset({'a': (('x', 'y'), + np.array([np.nan, 0], ndmin=2).T)}, + coords={'x': [0, 1], 'y': [0.0001]}) + + with raises_regex(ValueError, "indexes along dimension 'y'"): + actual = concat([ds1, ds2], join='exact', dim='x') + + for join in expected: + actual = concat([ds1, ds2], join=join, dim='x') + assert_equal(actual, expected[join].to_array()) From 8a9c4710b2ee389a41e08a665108aca05ef02544 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 7 Aug 2019 18:26:00 +0100 Subject: [PATCH 10/10] pyupgrade one-off run (#3190) * pyupgrade (manually vetted and tweaked) * pyupgrade * Tweaks to Dataset.drop_dims() * mypy * More concise code --- setup.cfg | 2 ++ versioneer.py | 23 +++++++------- xarray/_version.py | 10 +++--- xarray/backends/api.py | 4 +-- xarray/backends/netCDF4_.py | 12 ++++--- xarray/backends/netcdf3.py | 9 +++--- xarray/backends/pseudonetcdf_.py | 14 ++++----- xarray/backends/pynio_.py | 10 +++--- xarray/backends/zarr.py | 9 +++--- xarray/coding/cftime_offsets.py | 2 +- xarray/coding/times.py | 2 +- xarray/conventions.py | 2 +- xarray/convert.py | 6 ++-- xarray/core/alignment.py | 8 ++--- xarray/core/combine.py | 13 ++++---- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 49 +++++++++++++++-------------- xarray/core/formatting.py | 14 +++++---- xarray/core/groupby.py | 36 +++++++++++++-------- xarray/core/indexing.py | 2 +- xarray/core/ops.py | 6 ++-- xarray/core/variable.py | 18 ++++++----- xarray/plot/plot.py | 2 +- xarray/tests/test_backends.py | 22 +++++++------ xarray/tests/test_coding_strings.py | 4 +-- xarray/tests/test_coding_times.py | 2 +- xarray/tests/test_concat.py | 11 +++++-- xarray/tests/test_conventions.py | 6 ++-- xarray/tests/test_dataarray.py | 6 ++-- xarray/tests/test_dataset.py | 34 ++++++++++++-------- xarray/tests/test_plot.py | 8 ++--- xarray/tests/test_sparse.py | 8 ++--- xarray/tutorial.py | 2 +- 33 files changed, 202 insertions(+), 156 deletions(-) diff --git a/setup.cfg b/setup.cfg index 8b76dd27879..128550071cc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -83,6 +83,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-seaborn.*] ignore_missing_imports = True +[mypy-sparse.*] +ignore_missing_imports = True [mypy-toolz.*] ignore_missing_imports = True [mypy-zarr.*] diff --git a/versioneer.py b/versioneer.py index 577743023ca..e369108b439 100644 --- a/versioneer.py +++ b/versioneer.py @@ -398,7 +398,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, stderr=(subprocess.PIPE if hide_stderr else None)) break - except EnvironmentError: + except OSError: e = sys.exc_info()[1] if e.errno == errno.ENOENT: continue @@ -421,7 +421,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, return stdout, p.returncode -LONG_VERSION_PY['git'] = ''' +LONG_VERSION_PY['git'] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). 
Distribution tarballs (built by setup.py sdist) and build @@ -968,7 +968,7 @@ def git_get_keywords(versionfile_abs): if mo: keywords["date"] = mo.group(1) f.close() - except EnvironmentError: + except OSError: pass return keywords @@ -992,11 +992,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1005,7 +1005,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1148,7 +1148,7 @@ def do_vcs_install(manifest_in, versionfile_source, ipy): if "export-subst" in line.strip().split()[1:]: present = True f.close() - except EnvironmentError: + except OSError: pass if not present: f = open(".gitattributes", "a+") @@ -1206,7 +1206,7 @@ def versions_from_file(filename): try: with open(filename) as f: contents = f.read() - except EnvironmentError: + except OSError: raise NotThisMethod("unable to read _version.py") mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) @@ -1702,8 +1702,7 @@ def do_setup(): root = get_root() try: cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: + except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: if isinstance(e, (EnvironmentError, configparser.NoSectionError)): print("Adding sample versioneer config to setup.cfg", file=sys.stderr) @@ -1728,7 +1727,7 @@ def do_setup(): try: with open(ipy, "r") as f: old = f.read() - except EnvironmentError: + except OSError: old = "" if INIT_PY_SNIPPET not in old: print(" appending to %s" % ipy) @@ -1752,7 +1751,7 @@ def do_setup(): if line.startswith("include "): for include in line.split()[1:]: simple_includes.add(include) - except EnvironmentError: + except OSError: pass # That doesn't cover everything MANIFEST.in can do # (http://docs.python.org/2/distutils/sourcedist.html#commands), so diff --git a/xarray/_version.py b/xarray/_version.py index df4ee95ade4..442e56a04b0 100644 --- a/xarray/_version.py +++ b/xarray/_version.py @@ -81,7 +81,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, stderr=(subprocess.PIPE if hide_stderr else None)) break - except EnvironmentError: + except OSError: e = sys.exc_info()[1] if e.errno == errno.ENOENT: continue @@ -153,7 +153,7 @@ def git_get_keywords(versionfile_abs): if mo: keywords["date"] = mo.group(1) f.close() - except EnvironmentError: + except OSError: pass return keywords @@ -177,11 +177,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if 
verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -190,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2535c7118a5..292373d2a33 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -752,7 +752,7 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', paths = [str(p) if isinstance(p, Path) else p for p in paths] if not paths: - raise IOError('no files to open') + raise OSError('no files to open') # If combine='by_coords' then this is unnecessary, but quick. # If combine='nested' then this creates a flat list which is easier to @@ -1051,7 +1051,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, if groups is None: groups = [None] * len(datasets) - if len(set([len(datasets), len(paths), len(groups)])) > 1: + if len({len(datasets), len(paths), len(groups)}) > 1: raise ValueError('must supply lists of the same length for the ' 'datasets, paths and groups arguments to ' 'save_mfdataset') diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 962cba4012d..a93fba65d18 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -138,7 +138,7 @@ def _netcdf4_create_group(dataset, name): def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): - if group in set([None, '', '/']): + if group in {None, '', '/'}: # use the root group return ds else: @@ -155,7 +155,7 @@ def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): ds = create_group(ds, key) else: # wrap error to provide slightly more helpful message - raise IOError('group not found: %s' % key, e) + raise OSError('group not found: %s' % key, e) return ds @@ -195,9 +195,11 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, encoding = variable.encoding.copy() - safe_to_drop = set(['source', 'original_shape']) - valid_encodings = set(['zlib', 'complevel', 'fletcher32', 'contiguous', - 'chunksizes', 'shuffle', '_FillValue', 'dtype']) + safe_to_drop = {'source', 'original_shape'} + valid_encodings = { + 'zlib', 'complevel', 'fletcher32', 'contiguous', + 'chunksizes', 'shuffle', '_FillValue', 'dtype' + } if lsd_okay: valid_encodings.add('least_significant_digit') if h5py_okay: diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index 7f5c8d4b1a7..4985e51f689 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -11,9 +11,10 @@ # The following are reserved names in CDL and may not be used as names 
of # variables, dimension, attributes -_reserved_names = set(['byte', 'char', 'short', 'ushort', 'int', 'uint', - 'int64', 'uint64', 'float' 'real', 'double', 'bool', - 'string']) +_reserved_names = { + 'byte', 'char', 'short', 'ushort', 'int', 'uint', 'int64', 'uint64', + 'float' 'real', 'double', 'bool', 'string' +} # These data-types aren't supported by netCDF3, so they are automatically # coerced instead as indicated by the "coerce_nc3_dtype" function @@ -108,4 +109,4 @@ def is_valid_nc3_name(s): ('/' not in s) and (s[-1] != ' ') and (_isalnumMUTF8(s[0]) or (s[0] == '_')) and - all((_isalnumMUTF8(c) or c in _specialchars for c in s))) + all(_isalnumMUTF8(c) or c in _specialchars for c in s)) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 7a3f8a771e6..34a61ae8108 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -75,18 +75,18 @@ def get_variables(self): for k, v in self.ds.variables.items()) def get_attrs(self): - return Frozen(dict([(k, getattr(self.ds, k)) - for k in self.ds.ncattrs()])) + return Frozen({k: getattr(self.ds, k) for k in self.ds.ncattrs()}) def get_dimensions(self): return Frozen(self.ds.dimensions) def get_encoding(self): - encoding = {} - encoding['unlimited_dims'] = set( - [k for k in self.ds.dimensions - if self.ds.dimensions[k].isunlimited()]) - return encoding + return { + 'unlimited_dims': { + k for k in self.ds.dimensions + if self.ds.dimensions[k].isunlimited() + } + } def close(self): self._manager.close() diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index f8033551f96..9c3946f657d 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -75,10 +75,12 @@ def get_dimensions(self): return Frozen(self.ds.dimensions) def get_encoding(self): - encoding = {} - encoding['unlimited_dims'] = set( - [k for k in self.ds.dimensions if self.ds.unlimited(k)]) - return encoding + return { + 'unlimited_dims': { + k for k in self.ds.dimensions + if self.ds.unlimited(k) + } + } def close(self): self._manager.close() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index c0634fff009..effacd8b4b7 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -166,8 +166,7 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): def _extract_zarr_variable_encoding(variable, raise_on_invalid=False): encoding = variable.encoding.copy() - valid_encodings = set(['chunks', 'compressor', 'filters', - 'cache_metadata']) + valid_encodings = {'chunks', 'compressor', 'filters', 'cache_metadata'} if raise_on_invalid: invalid = [k for k in encoding if k not in valid_encodings] @@ -340,8 +339,10 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), only needed in append mode """ - existing_variables = set([vn for vn in variables - if _encode_variable_name(vn) in self.ds]) + existing_variables = { + vn for vn in variables + if _encode_variable_name(vn) in self.ds + } new_variables = set(variables) - existing_variables variables_without_encoding = OrderedDict([(vn, variables[vn]) for vn in new_variables]) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 400cfe11d33..7187f1266bd 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -637,7 +637,7 @@ def __apply__(self, other): _FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys()) -_PATTERN = r'^((?P\d+)|())(?P({0}))$'.format( +_PATTERN = r'^((?P\d+)|())(?P({}))$'.format( _FREQUENCY_CONDITION) diff --git 
a/xarray/coding/times.py b/xarray/coding/times.py index ea18b402ad2..4930a77d022 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -23,7 +23,7 @@ # standard calendars recognized by cftime -_STANDARD_CALENDARS = set(['standard', 'gregorian', 'proleptic_gregorian']) +_STANDARD_CALENDARS = {'standard', 'gregorian', 'proleptic_gregorian'} _NS_PER_TIME_DELTA = {'us': int(1e3), 'ms': int(1e6), diff --git a/xarray/conventions.py b/xarray/conventions.py index d0d90242426..616e557efcd 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -186,7 +186,7 @@ def ensure_dtype_not_object(var, name=None): if strings.is_bytes_dtype(inferred_dtype): fill_value = b'' elif strings.is_unicode_dtype(inferred_dtype): - fill_value = u'' + fill_value = '' else: # insist on using float for numeric values if not np.issubdtype(inferred_dtype, np.floating): diff --git a/xarray/convert.py b/xarray/convert.py index b8c0c2a7eca..83055631bb5 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -30,7 +30,7 @@ def encode(var): def _filter_attrs(attrs, ignored_attrs): """ Return attrs that are not in ignored_attrs """ - return dict((k, v) for k, v in attrs.items() if k not in ignored_attrs) + return {k: v for k, v in attrs.items() if k not in ignored_attrs} def from_cdms2(variable): @@ -119,7 +119,7 @@ def set_cdms2_attrs(var, attrs): def _pick_attrs(attrs, keys): """ Return attrs with keys in keys list """ - return dict((k, v) for k, v in attrs.items() if k in keys) + return {k: v for k, v in attrs.items() if k in keys} def _get_iris_args(attrs): @@ -188,7 +188,7 @@ def _iris_obj_to_attrs(obj): if obj.units.origin != '1' and not obj.units.is_unknown(): attrs['units'] = obj.units.origin attrs.update(obj.attributes) - return dict((k, v) for k, v in attrs.items() if v is not None) + return {k: v for k, v in attrs.items() if v is not None} def _iris_cell_methods_to_str(cell_methods_obj): diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 711634a95ca..1db9157850a 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -341,10 +341,10 @@ def reindex_variables( for dim, indexer in indexers.items(): if isinstance(indexer, DataArray) and indexer.dims != (dim,): warnings.warn( - "Indexer has dimensions {0:s} that are different " - "from that to be indexed along {1:s}. " - "This will behave differently in the future.".format( - str(indexer.dims), dim), + "Indexer has dimensions {:s} that are different " + "from that to be indexed along {:s}. " + "This will behave differently in the future." 
+ .format(str(indexer.dims), dim), FutureWarning, stacklevel=3) target = new_indexes[dim] = utils.safe_cast_to_index(indexers[dim]) diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 71da4e4e094..6a61cb2addc 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -41,9 +41,8 @@ def _infer_tile_ids_from_nested_list(entry, current_pos): if isinstance(entry, list): for i, item in enumerate(entry): - for result in _infer_tile_ids_from_nested_list(item, - current_pos + (i,)): - yield result + yield from _infer_tile_ids_from_nested_list( + item, current_pos + (i,)) else: yield current_pos, entry @@ -735,10 +734,12 @@ def _auto_concat(datasets, dim=None, data_vars='all', coords='different', concat_dims = set(ds0.dims) if ds0.dims != ds1.dims: dim_tuples = set(ds0.dims.items()) - set(ds1.dims.items()) - concat_dims = set(i for i, _ in dim_tuples) + concat_dims = {i for i, _ in dim_tuples} if len(concat_dims) > 1: - concat_dims = set(d for d in concat_dims - if not ds0[d].equals(ds1[d])) + concat_dims = { + d for d in concat_dims + if not ds0[d].equals(ds1[d]) + } if len(concat_dims) > 1: raise ValueError('too many different dimensions to ' 'concatenate: %s' % concat_dims) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 19d595079e5..70d11fe18ca 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1389,7 +1389,7 @@ def expand_dims(self, dim: Union[None, Hashable, Sequence[Hashable], elif isinstance(dim, Sequence) and not isinstance(dim, str): if len(dim) != len(set(dim)): raise ValueError('dims should not contain duplicate values.') - dim = OrderedDict(((d, 1) for d in dim)) + dim = OrderedDict((d, 1) for d in dim) elif dim is not None and not isinstance(dim, Mapping): dim = OrderedDict(((cast(Hashable, dim), 1),)) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5d3ca932ccc..3d2ef53a034 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -113,7 +113,7 @@ def calculate_dimensions( """ dims = {} # type: Dict[Any, int] last_used = {} - scalar_vars = set(k for k, v in variables.items() if not v.dims) + scalar_vars = {k for k, v in variables.items() if not v.dims} for k, var in variables.items(): for dim, size in zip(var.dims, var.shape): if dim in scalar_vars: @@ -997,7 +997,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> 'Dataset': for v in variables.values(): needed_dims.update(v.dims) - dims = dict((k, self.dims[k]) for k in needed_dims) + dims = {k: self.dims[k] for k in needed_dims} for k in self._coord_names: if set(self.variables[k].dims) <= needed_dims: @@ -1569,7 +1569,7 @@ def chunk( def selkeys(dict_, keys): if dict_ is None: return None - return dict((d, dict_[d]) for d in keys if d in dict_) + return {d: dict_[d] for d in keys if d in dict_} def maybe_chunk(name, var, chunks): chunks = selkeys(chunks, var.dims) @@ -1923,7 +1923,7 @@ def relevant_keys(mapping): raise ValueError('Indexers must be 1 dimensional') # all the indexers should have the same length - lengths = set(len(v) for k, v in indexers) + lengths = {len(v) for k, v in indexers} if len(lengths) > 1: raise ValueError('All indexers must be the same length') @@ -2577,7 +2577,7 @@ def swap_dims( 'variable along the old dimension %r' % (v, k)) - result_dims = set(dims_dict.get(dim, dim) for dim in self.dims) + result_dims = {dims_dict.get(dim, dim) for dim in self.dims} coord_names = self._coord_names.copy() coord_names.update(dims_dict.values()) @@ -2674,7 +2674,7 @@ def expand_dims( elif isinstance(dim, Sequence): if 
len(dim) != len(set(dim)): raise ValueError('dims should not contain duplicate values.') - dim = OrderedDict(((d, 1) for d in dim)) + dim = OrderedDict((d, 1) for d in dim) dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims') assert isinstance(dim, MutableMapping) @@ -2905,7 +2905,7 @@ def _stack_once(self, dims, new_dim): idx = utils.multiindex_from_product_levels(levels, names=dims) variables[new_dim] = IndexVariable(new_dim, idx) - coord_names = set(self._coord_names) - set(dims) | set([new_dim]) + coord_names = set(self._coord_names) - set(dims) | {new_dim} indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k not in dims) @@ -3103,7 +3103,7 @@ def _unstack_once(self, dim: Hashable) -> 'Dataset': variables[name] = IndexVariable(name, lev) indexes[name] = lev - coord_names = set(self._coord_names) - set([dim]) | set(new_dim_names) + coord_names = set(self._coord_names) - {dim} | set(new_dim_names) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes) @@ -3348,7 +3348,7 @@ def _drop_vars( variables = OrderedDict((k, v) for k, v in self._variables.items() if k not in names) - coord_names = set(k for k in self._coord_names if k in variables) + coord_names = {k for k in self._coord_names if k in variables} indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k not in names) return self._replace_with_new_dims( @@ -3364,7 +3364,7 @@ def drop_dims( Parameters ---------- - drop_dims : str or list + drop_dims : hashable or iterable of hashable Dimension or dimensions to drop. errors: {'raise', 'ignore'}, optional If 'raise' (default), raises a ValueError error if any of the @@ -3386,18 +3386,20 @@ def drop_dims( raise ValueError('errors must be either "raise" or "ignore"') if isinstance(drop_dims, str) or not isinstance(drop_dims, Iterable): - drop_dims = [drop_dims] + drop_dims = {drop_dims} else: - drop_dims = list(drop_dims) + drop_dims = set(drop_dims) if errors == 'raise': - missing_dimensions = [d for d in drop_dims if d not in self.dims] - if missing_dimensions: + missing_dims = drop_dims - set(self.dims) + if missing_dims: raise ValueError('Dataset does not contain the dimensions: %s' - % missing_dimensions) + % missing_dims) - drop_vars = set(k for k, v in self._variables.items() - for d in v.dims if d in drop_dims) + drop_vars = { + k for k, v in self._variables.items() + if set(v.dims) & drop_dims + } return self._drop_vars(drop_vars) def transpose(self, *dims: Hashable) -> 'Dataset': @@ -3740,7 +3742,7 @@ def reduce( allow_lazy=allow_lazy, **kwargs) - coord_names = set(k for k in self.coords if k in variables) + coord_names = {k for k in self.coords if k in variables} indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k in variables) attrs = self.attrs if keep_attrs else None @@ -4079,7 +4081,7 @@ def from_dict(cls, d): DataArray.from_dict """ - if not set(['coords', 'data_vars']).issubset(set(d)): + if not {'coords', 'data_vars'}.issubset(set(d)): variables = d.items() else: import itertools @@ -4250,8 +4252,9 @@ def diff(self, dim, n=1, label='upper'): if n == 0: return self if n < 0: - raise ValueError('order `n` must be non-negative but got {0}' - ''.format(n)) + raise ValueError( + 'order `n` must be non-negative but got {}'.format(n) + ) # prepare slices kwargs_start = {dim: slice(None, -1)} @@ -4530,7 +4533,7 @@ def quantile(self, q, dim=None, interpolation='linear', """ if isinstance(dim, str): - dims = set([dim]) + dims = {dim} elif dim is None: dims = set(self.dims) else: @@ -4561,7 +4564,7 @@ 
def quantile(self, q, dim=None, interpolation='linear', variables[name] = var # construct the new dataset - coord_names = set(k for k in self.coords if k in variables) + coord_names = {k for k in self.coords if k in variables} indexes = OrderedDict( (k, v) for k, v in self.indexes.items() if k in variables ) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 00c813ece09..3ddffec8e5e 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -145,7 +145,7 @@ def format_item(x, timedelta_format=None, quote_strings=True): elif isinstance(x, (str, bytes)): return repr(x) if quote_strings else x elif isinstance(x, (float, np.float)): - return '{0:.4}'.format(x) + return '{:.4}'.format(x) else: return str(x) @@ -399,7 +399,7 @@ def short_data_repr(array): elif array._in_memory or array.size < 1e5: return short_array_repr(array.data) else: - return u'[{} values with dtype={}]'.format(array.size, array.dtype) + return '[{} values with dtype={}]'.format(array.size, array.dtype) def array_repr(arr): @@ -409,10 +409,12 @@ def array_repr(arr): else: name_str = '' - summary = [''.format( - type(arr).__name__, name_str, dim_summary(arr))] - - summary.append(short_data_repr(arr)) + summary = [ + ''.format( + type(arr).__name__, name_str, dim_summary(arr) + ), + short_data_repr(arr) + ] if hasattr(arr, 'coords'): if arr.coords: diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 0649ecab44f..2be0857a4d3 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -46,20 +46,30 @@ def _dummy_copy(xarray_obj): from .dataset import Dataset from .dataarray import DataArray if isinstance(xarray_obj, Dataset): - res = Dataset(dict((k, dtypes.get_fill_value(v.dtype)) - for k, v in xarray_obj.data_vars.items()), - dict((k, dtypes.get_fill_value(v.dtype)) - for k, v in xarray_obj.coords.items() - if k not in xarray_obj.dims), - xarray_obj.attrs) + res = Dataset( + { + k: dtypes.get_fill_value(v.dtype) + for k, v in xarray_obj.data_vars.items() + }, + { + k: dtypes.get_fill_value(v.dtype) + for k, v in xarray_obj.coords.items() + if k not in xarray_obj.dims + }, + xarray_obj.attrs + ) elif isinstance(xarray_obj, DataArray): - res = DataArray(dtypes.get_fill_value(xarray_obj.dtype), - dict((k, dtypes.get_fill_value(v.dtype)) - for k, v in xarray_obj.coords.items() - if k not in xarray_obj.dims), - dims=[], - name=xarray_obj.name, - attrs=xarray_obj.attrs) + res = DataArray( + dtypes.get_fill_value(xarray_obj.dtype), + { + k: dtypes.get_fill_value(v.dtype) + for k, v in xarray_obj.coords.items() + if k not in xarray_obj.dims + }, + dims=[], + name=xarray_obj.name, + attrs=xarray_obj.attrs + ) else: # pragma: no cover raise AssertionError return res diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index aea5a5a3f4f..a9ad55e2652 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -149,7 +149,7 @@ def convert_label_indexer(index, label, index_name='', method=None, raise ValueError('cannot use a dict-like object for selection on ' 'a dimension that does not have a MultiIndex') elif len(label) == index.nlevels and not is_nested_vals: - indexer = index.get_loc(tuple((label[k] for k in index.names))) + indexer = index.get_loc(tuple(label[k] for k in index.names)) else: for k, v in label.items(): # index should be an item (i.e. 
Hashable) not an array-like diff --git a/xarray/core/ops.py b/xarray/core/ops.py index 3759a7c5634..0c0fc1e50a8 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -258,8 +258,10 @@ def get_op(name): return getattr(operator, op_str(name)) -NON_INPLACE_OP = dict((get_op('i' + name), get_op(name)) - for name in NUM_BINARY_OPS) +NON_INPLACE_OP = { + get_op('i' + name): get_op(name) + for name in NUM_BINARY_OPS +} def inplace_to_noninplace_op(f): diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 3c9d85f13d7..85f26d85cd4 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -537,9 +537,10 @@ def _validate_indexers(self, key): if k.dtype.kind == 'b': if self.shape[self.get_axis_num(dim)] != len(k): raise IndexError( - "Boolean array size {0:d} is used to index array " - "with shape {1:s}.".format(len(k), - str(self.shape))) + "Boolean array size {:d} is used to index array " + "with shape {:s}." + .format(len(k), str(self.shape)) + ) if k.ndim > 1: raise IndexError("{}-dimensional boolean indexing is " "not supported. ".format(k.ndim)) @@ -547,8 +548,9 @@ def _validate_indexers(self, key): raise IndexError( "Boolean indexer should be unlabeled or on the " "same dimension to the indexed array. Indexer is " - "on {0:s} but the target dimension is " - "{1:s}.".format(str(k.dims), dim)) + "on {:s} but the target dimension is {:s}." + .format(str(k.dims), dim) + ) def _broadcast_indexes_outer(self, key): dims = tuple(k.dims[0] if isinstance(k, Variable) else dim @@ -888,8 +890,10 @@ def chunk(self, chunks=None, name=None, lock=False): import dask.array as da if utils.is_dict_like(chunks): - chunks = dict((self.get_axis_num(dim), chunk) - for dim, chunk in chunks.items()) + chunks = { + self.get_axis_num(dim): chunk + for dim, chunk in chunks.items() + } if chunks is None: chunks = self.chunks or self.shape diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index d0003b702df..26102a044e3 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -20,7 +20,7 @@ def _infer_line_data(darray, x, y, hue): - error_msg = ('must be either None or one of ({0:s})' + error_msg = ('must be either None or one of ({:s})' .format(', '.join([repr(dd) for dd in darray.dims]))) ndims = len(darray.dims) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index bef9da9fb7f..92f516b8c3b 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1077,11 +1077,11 @@ def test_open_encodings(self): with open_dataset(tmp_file) as actual: assert_equal(actual['time'], expected['time']) - actual_encoding = dict((k, v) for k, v in - actual['time'].encoding.items() - if k in expected['time'].encoding) - assert actual_encoding == \ - expected['time'].encoding + actual_encoding = { + k: v for k, v in actual['time'].encoding.items() + if k in expected['time'].encoding + } + assert actual_encoding == expected['time'].encoding def test_dump_encodings(self): # regression test for #709 @@ -2870,11 +2870,15 @@ def test_deterministic_names(self): data = create_test_data() data.to_netcdf(tmp) with open_mfdataset(tmp, combine='by_coords') as ds: - original_names = dict((k, v.data.name) - for k, v in ds.data_vars.items()) + original_names = { + k: v.data.name + for k, v in ds.data_vars.items() + } with open_mfdataset(tmp, combine='by_coords') as ds: - repeat_names = dict((k, v.data.name) - for k, v in ds.data_vars.items()) + repeat_names = { + k: v.data.name + for k, v in ds.data_vars.items() + } for var_name, dask_name in original_names.items(): 
assert var_name in dask_name assert dask_name[:13] == 'open_dataset-' diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 98824c9136c..13c0983212e 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -35,7 +35,7 @@ def test_vlen_dtype(): def test_EncodedStringCoder_decode(): coder = strings.EncodedStringCoder() - raw_data = np.array([b'abc', 'ß∂µ∆'.encode('utf-8')]) + raw_data = np.array([b'abc', 'ß∂µ∆'.encode()]) raw = Variable(('x',), raw_data, {'_Encoding': 'utf-8'}) actual = coder.decode(raw) @@ -50,7 +50,7 @@ def test_EncodedStringCoder_decode(): def test_EncodedStringCoder_decode_dask(): coder = strings.EncodedStringCoder() - raw_data = np.array([b'abc', 'ß∂µ∆'.encode('utf-8')]) + raw_data = np.array([b'abc', 'ß∂µ∆'.encode()]) raw = Variable(('x',), raw_data, {'_Encoding': 'utf-8'}).chunk() actual = coder.decode(raw) assert isinstance(actual.data, da.Array) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index dacf68f6be8..82afeab7aba 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -451,7 +451,7 @@ def test_decode_360_day_calendar(): calendar = '360_day' # ensure leap year doesn't matter for year in [2010, 2011, 2012, 2013, 2014]: - units = 'days since {0}-01-01'.format(year) + units = 'days since {}-01-01'.format(year) num_times = np.arange(100) if cftime.__name__ == 'cftime': diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index deed6748761..ff188305c83 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -29,9 +29,14 @@ def test_concat(self): def rectify_dim_order(dataset): # return a new dataset with all variable dimensions transposed into # the order in which they are found in `data` - return Dataset(dict((k, v.transpose(*data[k].dims)) - for k, v in dataset.data_vars.items()), - dataset.coords, attrs=dataset.attrs) + return Dataset( + { + k: v.transpose(*data[k].dims) + for k, v in dataset.data_vars.items() + }, + dataset.coords, + attrs=dataset.attrs + ) for dim in ['dim1', 'dim2']: datasets = [g for _, g in data.groupby(dim, squeeze=False)] diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index b9690c211f4..e7cb8006b08 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -115,9 +115,9 @@ def test_multidimensional_coordinates(self): foo1_coords = enc['foo1'].attrs.get('coordinates', '') foo2_coords = enc['foo2'].attrs.get('coordinates', '') foo3_coords = enc['foo3'].attrs.get('coordinates', '') - assert set(foo1_coords.split()) == set(['lat1', 'lon1']) - assert set(foo2_coords.split()) == set(['lat2', 'lon2']) - assert set(foo3_coords.split()) == set(['lat3', 'lon3']) + assert set(foo1_coords.split()) == {'lat1', 'lon1'} + assert set(foo2_coords.split()) == {'lat2', 'lon2'} + assert set(foo3_coords.split()) == {'lat3', 'lon3'} # Should not have any global coordinates. 
assert 'coordinates' not in attrs diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 000469f24bf..3a19c229fe6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1685,7 +1685,7 @@ def test_math_with_coords(self): assert_identical(expected, actual) actual = orig + orig[0, 0] - exp_coords = dict((k, v) for k, v in coords.items() if k != 'lat') + exp_coords = {k: v for k, v in coords.items() if k != 'lat'} expected = DataArray(orig.values + orig.values[0, 0], exp_coords, dims=['x', 'y']) assert_identical(expected, actual) @@ -3377,7 +3377,7 @@ def test__title_for_slice(self): assert '' == a2._title_for_slice() def test__title_for_slice_truncate(self): - array = DataArray(np.ones((4))) + array = DataArray(np.ones(4)) array.coords['a'] = 'a' * 100 array.coords['b'] = 'b' * 100 @@ -3773,7 +3773,7 @@ def test_rolling_wrapped_bottleneck(da, name, center, min_periods): # Test all bottleneck functions rolling_obj = da.rolling(time=7, min_periods=min_periods) - func_name = 'move_{0}'.format(name) + func_name = 'move_{}'.format(name) actual = getattr(rolling_obj, name)() expected = getattr(bn, func_name)(da.values, window=7, axis=1, min_count=min_periods) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fc15393f269..78891045bae 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -136,8 +136,10 @@ def lazy_inaccessible(k, v): data = indexing.LazilyOuterIndexedArray( InaccessibleArray(v.values)) return Variable(v.dims, data, v.attrs) - return dict((k, lazy_inaccessible(k, v)) for - k, v in self._variables.items()) + return { + k: lazy_inaccessible(k, v) + for k, v in self._variables.items() + } class TestDataset: @@ -239,7 +241,7 @@ def test_unicode_data(self): repr(data) # should not raise byteorder = '<' if sys.byteorder == 'little' else '>' - expected = dedent(u"""\ + expected = dedent("""\ Dimensions: (foø: 1) Coordinates: @@ -520,7 +522,7 @@ def test_attr_access(self): assert ds.title == ds.attrs['title'] assert ds.tmin.units == ds['tmin'].attrs['units'] - assert set(['tmin', 'title']) <= set(dir(ds)) + assert {'tmin', 'title'} <= set(dir(ds)) assert 'units' in set(dir(ds.tmin)) # should defer to variable of same name @@ -1953,8 +1955,9 @@ def test_drop_variables(self): assert_identical(data, data.drop([])) - expected = Dataset(dict((k, data[k]) for k in data.variables - if k != 'time')) + expected = Dataset( + {k: data[k] for k in data.variables if k != 'time'} + ) actual = data.drop('time') assert_identical(expected, actual) actual = data.drop(['time']) @@ -2951,9 +2954,9 @@ def test_delitem(self): all_items = set(data.variables) assert set(data.variables) == all_items del data['var1'] - assert set(data.variables) == all_items - set(['var1']) + assert set(data.variables) == all_items - {'var1'} del data['numbers'] - assert set(data.variables) == all_items - set(['var1', 'numbers']) + assert set(data.variables) == all_items - {'var1', 'numbers'} assert 'numbers' not in data.coords expected = Dataset() @@ -2966,8 +2969,12 @@ def test_squeeze(self): for args in [[], [['x']], [['x', 'z']]]: def get_args(v): return [set(args[0]) & set(v.dims)] if args else [] - expected = Dataset(dict((k, v.squeeze(*get_args(v))) - for k, v in data.variables.items())) + expected = Dataset( + { + k: v.squeeze(*get_args(v)) + for k, v in data.variables.items() + } + ) expected = expected.set_coords(data.coords) assert_identical(expected, data.squeeze(*args)) # invalid squeeze @@ -3869,8 +3876,9 @@ 
def test_reduce(self): assert len(data.mean().coords) == 0 actual = data.max() - expected = Dataset(dict((k, v.max()) - for k, v in data.data_vars.items())) + expected = Dataset( + {k: v.max() for k, v in data.data_vars.items()} + ) assert_equal(expected, actual) assert_equal(data.min(dim=['dim1']), @@ -4981,7 +4989,7 @@ def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key): # Test all bottleneck functions rolling_obj = ds.rolling(time=7, min_periods=min_periods) - func_name = 'move_{0}'.format(name) + func_name = 'move_{}'.format(name) actual = getattr(rolling_obj, name)() if key == 'z1': # z1 does not depend on 'Time' axis. Stored as it is. expected = ds[key] diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 2a13c131bf3..d6a580048c7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -45,7 +45,7 @@ def substring_in_axes(substring, ax): ''' Return True if a substring is found anywhere in an axes ''' - alltxt = set([t.get_text() for t in ax.findobj(mpl.text.Text)]) + alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)} for txt in alltxt: if substring in txt: return True @@ -1158,9 +1158,9 @@ def test_facetgrid_cmap(self): d = DataArray(data, dims=['x', 'y', 'time']) fg = d.plot.pcolormesh(col='time') # check that all color limits are the same - assert len(set(m.get_clim() for m in fg._mappables)) == 1 + assert len({m.get_clim() for m in fg._mappables}) == 1 # check that all colormaps are the same - assert len(set(m.get_cmap().name for m in fg._mappables)) == 1 + assert len({m.get_cmap().name for m in fg._mappables}) == 1 def test_facetgrid_cbar_kwargs(self): a = easy_array((10, 15, 2, 3)) @@ -1498,7 +1498,7 @@ def test_names_appear_somewhere(self): self.darray.name = 'testvar' self.g.map_dataarray(xplt.contourf, 'x', 'y') for k, ax in zip('abc', self.g.axes.flat): - assert 'z = {0}'.format(k) == ax.get_title() + assert 'z = {}'.format(k) == ax.get_title() alltxt = text_in_fig() assert self.darray.name in alltxt diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index 3aa407f72bc..329952bc064 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -535,7 +535,7 @@ def test_ufuncs(self): def test_dataarray_repr(self): a = xr.DataArray( - COO.from_numpy(np.ones((4))), + COO.from_numpy(np.ones(4)), dims=['x'], coords={'y': ('x', COO.from_numpy(np.arange(4)))}) expected = dedent("""\ @@ -548,7 +548,7 @@ def test_dataarray_repr(self): def test_dataset_repr(self): ds = xr.Dataset( - data_vars={'a': ('x', COO.from_numpy(np.ones((4))))}, + data_vars={'a': ('x', COO.from_numpy(np.ones(4)))}, coords={'y': ('x', COO.from_numpy(np.arange(4)))}) expected = dedent("""\ @@ -562,7 +562,7 @@ def test_dataset_repr(self): def test_dataarray_pickle(self): a1 = xr.DataArray( - COO.from_numpy(np.ones((4))), + COO.from_numpy(np.ones(4)), dims=['x'], coords={'y': ('x', COO.from_numpy(np.arange(4)))}) a2 = pickle.loads(pickle.dumps(a1)) @@ -570,7 +570,7 @@ def test_dataarray_pickle(self): def test_dataset_pickle(self): ds1 = xr.Dataset( - data_vars={'a': ('x', COO.from_numpy(np.ones((4))))}, + data_vars={'a': ('x', COO.from_numpy(np.ones(4)))}, coords={'y': ('x', COO.from_numpy(np.arange(4)))}) ds2 = pickle.loads(pickle.dumps(ds1)) assert_identical(ds1, ds2) diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 01d4f181d7f..0d9009f439d 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -77,7 +77,7 @@ def open_dataset(name, cache=True, cache_dir=_default_cache_dir, msg = """ MD5 checksum does 
not match, try downloading dataset again. """ - raise IOError(msg) + raise OSError(msg) ds = _open_dataset(localfile, **kws)
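
Editorial note (not part of the patch): the changes in this commit are mechanical Python 3 clean-ups -- `dict(...)`/`set(...)` calls on generator expressions become comprehensions or literals, explicit `'{0}'` format indices become auto-numbered `'{}'` fields, and `raise IOError` becomes `raise OSError`, which is behaviour-preserving because `IOError` has been an alias of `OSError` since Python 3.3. A minimal sketch checking these equivalences follows; the data and names are illustrative only, not taken from xarray.

    # Illustrative sanity checks for the equivalences this clean-up relies on.
    pairs = [("a", 1), ("b", 2)]

    # dict()/set() over a generator expression vs. a comprehension / literal
    assert dict((k, v) for k, v in pairs) == {k: v for k, v in pairs}
    assert set(k for k, _ in pairs) == {"a", "b"}
    assert set(["a", "b"]) == {"a", "b"}

    # explicit '{0}' indices vs. auto-numbered '{}' fields
    assert "move_{0}".format("mean") == "move_{}".format("mean")

    # since Python 3.3, IOError is simply an alias of OSError
    assert IOError is OSError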